From f59b8c7d3db6a78685d7330d377cb8095c359434 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 12 Jun 2025 09:46:33 -0700 Subject: [PATCH 001/103] OCP FP8 Macro restructure (#2331) * solved the problem --- include/ck_tile/core/config.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 14b33aea77..1ecc28fbeb 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -240,17 +240,17 @@ #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #endif -#ifndef __HIP_DEVICE_COMPILE__ // for host code -#ifdef CK_TILE_USE_OCP_FP8 +#ifndef CK_TILE_USE_OCP_FP8 +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(__gfx950__) || defined(__gfx12__) #define CK_TILE_USE_OCP_FP8 1 #else #define CK_TILE_USE_OCP_FP8 0 #endif -#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code -#define CK_TILE_USE_OCP_FP8 1 -#else // for GPU code +#else #define CK_TILE_USE_OCP_FP8 0 #endif +#endif #ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN #if __clang_major__ == 20 From e5ece1446782b99877792d51e4ed3119dfd7000a Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 12 Jun 2025 18:27:14 -0400 Subject: [PATCH 002/103] fix(gemm_universal): Update gemm_utils.hpp so it builds successfully for memory pipeline (#2336) --- example/ck_tile/03_gemm/gemm_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index aec5f6a116..cd4ace6d2f 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -49,7 +49,7 @@ struct GemmConfig static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 8; + static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; #endif From 5f1ad09b610cb0e083f63988479ab022bda70588 Mon 
Sep 17 00:00:00 2001 From: kylasa Date: Thu, 12 Jun 2025 18:24:02 -0700 Subject: [PATCH 003/103] Code drop for 2 warp ping pong scheduler along K dimension. (#2276) * Code drop for 2 warp ping pong scheduler along K dimension. * Addressing code review comments. * Addressing Clang formatting issues. * Addressing build issues. * Addressing build issues of other GEMM pipelines with ping pong scheduler code drop. * Fix for LDS memory size for GEMM pipelines. * Addressing code review feedback comments. * Change log update. * Addressing code review comments and build issues. * Added new policy for pipeline specific logic about LDS needs. * Clang Fix during build. --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 35 +- example/ck_tile/03_gemm/universal_gemm.cpp | 8 +- .../algorithm/static_encoding_pattern.hpp | 92 +++-- .../ops/epilogue/cshuffle_epilogue.hpp | 7 +- include/ck_tile/ops/gemm.hpp | 2 + .../block/block_gemm_areg_breg_creg_v1.hpp | 160 ++++++-- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 17 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 10 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 1 + .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 1 + .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 379 ++++++++++++++++++ ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 63 +++ .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 1 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 54 ++- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 4 +- 17 files changed, 727 insertions(+), 110 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index aecf16d83d..af8d965b30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Split K for grouped 
convolution backward data. * Added logit soft-capping support for fMHA forward kernels. * Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. ### Optimized diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index cd4ace6d2f..f3d11c751b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,6 +14,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 #ifndef CK_TILE_PIPELINE_DEFAULT #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 @@ -31,6 +32,10 @@ #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave #else #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif @@ -51,7 +56,8 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler @@ -67,7 +73,8 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t 
NumWaveGroups = 1; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) // Compute friendly for Intrawave scheduler // Using the ping pong reader in the lds level @@ -83,7 +90,29 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t NumWaveGroups = 1; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + + // Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups + // will be responsible for specific jobs. For instance, perform Global Memory read operations, + // perform block-gemm operation etc... 
+ static constexpr ck_tile::index_t NumWaveGroups = 2; #endif static constexpr bool kPadM = false; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3a7cc93df8..fafe40c333 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -50,7 +50,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& CLayout, GemmConfig::TransposeC, GemmConfig::UseStructuredSparsity, - Persistent>; + Persistent, + GemmConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -96,7 +97,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, - memory_operation>>; + memory_operation, + GemmConfig::NumWaveGroups>>; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -190,7 +193,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& }; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; } diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index b56bda3741..d8a8f6ab66 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -56,19 +56,24 @@ template + tile_distribution_pattern DistributionPattern, + index_t NumWaveGroups = 1> struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { }; // Thread raked -template +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::thread_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! 
@@ -83,45 +88,76 @@ struct TileDistributionEncodingPattern2D, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + if constexpr(NumWaveGroups != 1) + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } } CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() { - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<1, 2>>{}); + if constexpr(NumWaveGroups != 1) + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); + } } }; // Warp raked -template +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::warp_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -164,13 +200,17 @@ struct TileDistributionEncodingPattern2D +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::block_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! 
diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 5a6521deb5..6613ceebb2 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -23,7 +23,8 @@ template + memory_operation_enum MemoryOperation_, + index_t kNumWaveGroups_ = 1> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -41,6 +42,7 @@ struct CShuffleEpilogueProblem static constexpr index_t KPerXdl = KPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; }; template @@ -236,7 +238,8 @@ struct CShuffleEpilogue MPerIterationShuffle, NPerIterationShuffle, GetVectorSizeC(), - tile_distribution_pattern::thread_raked>; + tile_distribution_pattern::thread_raked, + Problem::kNumWaveGroups>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); constexpr auto c_warp_y_lengths = diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 35f5170179..8db822ebd1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,6 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp 
b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index b4362d9069..28d8b3eead 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -60,52 +60,105 @@ struct BlockGemmARegBRegCRegV1 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; - return a_block_dstr_encode; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto b_block_outer_dstr_encoding = - 
tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, 
sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } // C += A * B @@ -201,19 +254,38 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } } // C = A * B diff --git 
a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index edcde4a09f..bfb0d2626b 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -644,6 +644,7 @@ struct GemmKernel * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * */ + template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -671,11 +672,15 @@ struct GemmKernel const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr_0); + } } /** @@ -772,7 +777,9 @@ struct GemmKernel EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 24bd66a59e..07bfb33252 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -71,7 +71,8 @@ struct GemmPipelineAgBgCrImplBase template CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view, - const ALdsLoadTileDistr&) const + const ALdsLoadTileDistr&, + const array& offset 
= {0, 0}) const { constexpr bool is_col_major = std::is_same_v; @@ -82,7 +83,7 @@ struct GemmPipelineAgBgCrImplBase auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(YPerTile{}, XPerTile{}), - a_dram_block_window_tmp.get_window_origin(), + a_dram_block_window_tmp.get_window_origin() + offset, Policy::template MakeADramTileDistribution()); // A LDS tile window for store @@ -103,7 +104,8 @@ struct GemmPipelineAgBgCrImplBase template CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, const BLdsTensorView& b_lds_block_view, - const BLdsLoadTileDistr&) const + const BLdsLoadTileDistr&, + const array& offset = {0, 0}) const { constexpr bool is_row_major = std::is_same_v; @@ -113,7 +115,7 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_dram_window = make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(YPerTile{}, XPerTile{}), - b_dram_block_window_tmp.get_window_origin(), + b_dram_block_window_tmp.get_window_origin() + offset, Policy::template MakeBDramTileDistribution()); // TODO: Do we really need those two tile windows??? 
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index a6267e4c89..eb47d9bad6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -143,6 +143,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 6fc6ba2ba2..8424c43e86 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -134,6 +134,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp new file mode 100644 index 0000000000..9ef7f3f0ef --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -0,0 +1,379 @@ +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" 
+#include "ck_tile/host/concat.hpp" + +namespace ck_tile { +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed Tensor: register + +template +struct BaseGemmPipelineAgBgCrCompV5 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; + +template +struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 +{ + using Base = BaseGemmPipelineAgBgCrCompV5; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template 
GetVectorSizeC(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr index_t NumWarps = BlockGemmShape::NumWarps; + static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV5", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + static_assert( + KPerBlock % ((NumWarps / 2) * KTileSize) == 0, + "Ping Pong Warps, TileSize and Block Size for K dimensions does not match."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? 
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + index_t warp_id = get_warp_id(); + index_t operation_id = + __builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm + + auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); + auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); + + auto tensor_views = + Base::GetABLdsTensorViews(static_cast(static_cast(p_smem_0))); + auto& a_lds_block = tensor_views.get(number<0>{}); + auto& b_lds_block = tensor_views.get(number<1>{}); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto a_windows = Base::GetAWindows( + a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset); + auto& a_copy_dram_window = a_windows.get(number<0>{}); + auto& a_copy_lds_window = a_windows.get(number<1>{}); + auto& a_lds_window = a_windows.get(number<2>{}); + + auto b_windows = Base::GetBWindows( + b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset); + auto& b_copy_dram_window = b_windows.get(number<0>{}); + auto& b_copy_lds_window = b_windows.get(number<1>{}); + auto& b_lds_window = b_windows.get(number<2>{}); + + // 
DRAM window steps. + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock * NumWarps, 0) + : make_array(0, KPerBlock * NumWarps); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock * NumWarps, 0) + : make_array(0, KPerBlock * NumWarps); + + constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using AGemmTile = decltype(make_static_distributed_tensor(AGemmTileDistr)); + using BGemmTile = decltype(make_static_distributed_tensor(BGemmTileDistr)); + AGemmTile a_tile_0, a_tile_1; + BGemmTile b_tile_0, b_tile_1; + + // Register tile for A and B. + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + ABlockTile a_global_load_tile; + BBlockTile b_global_load_tile; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); + auto c_block_tile_1 = block_gemm.MakeCBlockTile(); + + CDataType* __restrict__ p_c_lds = static_cast(p_smem_0); + auto c_lds_block_0 = + make_naive_tensor_view(p_c_lds, + make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, 1), + number{}, + number<1>{}); + auto c_window_0 = make_tile_window(c_lds_block_0, + make_tuple(number{}, number{}), + {0, 0}, + c_block_tile_1.get_tile_distribution()); + + // initialize C + if(warp_id == 0) + { + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0); + 
} + else + { + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1); + } + + // define ping, pong steps here as lambda functions. + auto MemoryOpsStep = [&](auto idx) { + // Memory read half here. + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + } + + if(idx == 0) + { + Base::LocalPrefetch(a_tile_0, a_lds_window); + Base::LocalPrefetch(b_tile_0, b_lds_window); + } + else + { + Base::LocalPrefetch(a_tile_1, a_lds_window); + Base::LocalPrefetch(b_tile_1, b_lds_window); + } + }; + + auto ComputeStep = [&](auto idx) { + if(idx == 0) + { + block_gemm(c_block_tile_0, a_tile_0, b_tile_0); + } + else + { + block_gemm(c_block_tile_1, a_tile_1, b_tile_1); + } + }; + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + } + + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_compute_steps > 1) + { + block_sync_lds(); + operation_id = (operation_id + 1) % NumWaveGroups; + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + } + else + { + ComputeStep(warp_id); + } + num_compute_steps -= 1; + } + block_sync_lds(); + + if(operation_id == 0) + { + ComputeStep(warp_id); + } + 
block_sync_lds(); + + if(warp_id == 1) + { + store_tile(c_window_0, c_block_tile_1); + } + block_sync_lds(); + + if(warp_id == 0) + { + load_tile(c_block_tile_1, c_window_0); + + constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + auto idx2 = make_tuple(idx0, idx1); + c_block_tile_0(idx2) += c_block_tile_1(idx2); + }); + }); + } + return c_block_tile_0; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp new file mode 100644 index 0000000000..c03db08c3f --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV5, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. +struct GemmPipelineAgBgCrCompV5DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_DEVICE static constexpr index_t GetSmemSizeC() + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + + return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock, + 16); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + constexpr index_t smem_size_c = GetSmemSizeC(); + + return smem_size_a + smem_size_b >= smem_size_c ? 
(smem_size_a + smem_size_b) + : (smem_size_c); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index f7b5f9b3cb..1f2ab80797 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -188,6 +188,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Where is the right place for HasHotLoop and TailNum ??? static constexpr bool HasHotLoop = Problem::HasHotLoop; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 0b38e7789e..678fb6eb46 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -198,6 +198,8 @@ struct UniversalGemmPipelineProblem static constexpr bool TransposeC = Traits::TransposeC; static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6890cf2f64..91e845d200 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = 
Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) @@ -438,7 +439,8 @@ struct UniversalGemmBasePolicy MPerBlock, KPerBlock, VecLoadSize, - ATileAccessPattern>; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } // Tile: KPerBlock X MPerBlock @@ -448,7 +450,8 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, - ATileAccessPattern>; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } } @@ -458,10 +461,11 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) @@ -470,7 +474,8 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, - BTileAccessPattern>; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } // Tile: NPerBlock X KPerBlock @@ -480,7 +485,8 @@ struct UniversalGemmBasePolicy NPerBlock, KPerBlock, VecLoadSize, - BTileAccessPattern>; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } } 
@@ -490,16 +496,18 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using TileEncodingPattern = TileDistributionEncodingPattern2D; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } @@ -508,16 +516,18 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using TileEncodingPattern = TileDistributionEncodingPattern2D; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index a61b0eee3c..353192d86f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -39,7 +39,8 @@ template + bool UsePersistentKernel_ = false, + index_t NumWaveGroups_ = 1> struct 
TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -55,6 +56,7 @@ struct TileGemmUniversalTraits static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; static constexpr bool UsePersistentKernel = UsePersistentKernel_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; }; template Date: Fri, 13 Jun 2025 03:58:50 -0700 Subject: [PATCH 004/103] Shard several of the most costly targets. (#2266) * Shard several of the most costly targets. Introduces a filter_tuple_by_modulo to break up tuples. Drops build time of target from 21 minutes to under 14 minutes with 64 build processes, or 11 minutes with 128 build processes. time ninja -j 64 device_grouped_conv3d_fwd_instance * fix clang format * Fix build errors in instantiation code. I wasn't sure how to test the header-only instantiation code on my initial commit. From Jenkins CI test results, I see that there is a test target that depends on these headers: ninja -j 128 test_grouped_convnd_fwd This allowed me to test the build locally. I found three mistakes I made, mostly related to early experiments on I tried on the code. This was hard to find earlier because this PR is really too large. I also discovered that there are five 2D convolution targets that now dominate the compilation time. I will likely address those in a later PR, rather than adding even more changes to this PR. * Fix link errors from mismatched declarations. Our pattern for instantiating MIOpen templates uses duplicate declarations (instead of headers). This is fragile, and I didn't notice that my last commit had a bunch of link errors. I fixed these mistakes, and the bin/test_grouped_conv_fwd test target binary now links correctly. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for the intantiating the kerenels and to generate the calling function. 
* Shard the longest 2D convolution builds Now that we have automated the shard instantiation, we can shard the 2D convolution targets that take the longest to build. The target test_grouped_conv2d_fwd now compiles in 15 minutes. * Use PROJECT_SOURCE_DIR for submodule compatibility I used CMAKE_SOURCE_DIR to refer to the top-level source directory in the ShardInstantiation.cmake file, but this can cause issues with git submodules. Instead, we should use PROJECT_SOURCE_DIR to ensure compatibility when this project is used as a submodule in another project. --------- Co-authored-by: illsilin --- .gitignore | 3 + cmake/ShardInstantiation.cmake | 116 ++++++++++++++++++ cmake/call_shard.in | 15 +++ cmake/instantiate_shard.in | 9 ++ include/ck/utility/filter_tuple.hpp | 66 ++++++++++ .../gpu/grouped_convolution_forward_xdl.inc | 3 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 51 +++++++- ...l_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} | 38 +++--- ...wd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} | 40 +++--- ...fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} | 64 +++++----- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 66 ---------- ...wgc_gkyxc_nhwgk_int8_mem_inter_instance.in | 80 ++++++++++++ ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 66 ---------- ...wgc_gkyxc_nhwgk_int8_mem_intra_instance.in | 80 ++++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 109 +++++++++++++--- ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 111 ----------------- ...ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in | 66 ++++++++++ ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 111 ----------------- ..._ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in | 65 ++++++++++ ...gcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp | 54 -------- ...ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in | 65 ++++++++++ ...ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp | 54 -------- ..._ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in | 63 ++++++++++ ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 53 -------- 
...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} | 53 ++++---- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 53 -------- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} | 53 ++++---- ...ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp | 9 ++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp | 9 ++ ...w_gkczyx_ngkdhw_bf16_mem_inter_instance.in | 64 ++++++++++ ...w_gkczyx_ngkdhw_bf16_mem_intra_instance.in | 65 ++++++++++ ...w_gkczyx_ngkdhw_f16_mem_inter_instance.in} | 69 ++++++----- ...w_gkczyx_ngkdhw_f16_mem_intra_instance.in} | 75 ++++++----- ...w_gkczyx_ngkdhw_f32_mem_inter_instance.in} | 69 ++++++----- ...w_gkczyx_ngkdhw_f32_mem_intra_instance.in} | 69 ++++++----- ...w_gkczyx_ngkdhw_f32_mem_intra_instance.inc | 62 ++++++++++ 42 files changed, 1325 insertions(+), 827 deletions(-) create mode 100644 cmake/ShardInstantiation.cmake create mode 100644 cmake/call_shard.in create mode 100644 cmake/instantiate_shard.in create mode 100644 include/ck/utility/filter_tuple.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} (53%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} (71%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} (64%) delete mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in delete mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} (64%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} (64%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in} (57%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in} (59%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc diff --git a/.gitignore b/.gitignore index 599ef99e35..e4dd8f7513 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,6 @@ build*/ # Python cache __pycache__/ + +.cache/ + diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate 
templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, +# which can be placed in the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instantiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. `instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCES_NAME}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. 
+ + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function definitions. + # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. 
+ set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. + set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. 
+ list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/include/ck/utility/filter_tuple.hpp b/include/ck/utility/filter_tuple.hpp new file mode 100644 index 0000000000..c2e378b879 --- /dev/null +++ b/include/ck/utility/filter_tuple.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck::util { + +template +struct filter_tuple_by_modulo +{ + // Validate Stride and Offset. + static_assert(Stride > 0, "Stride must be positive."); + static_assert(Offset >= 0 && Offset < Stride, + "Offset must be non-negative and less than the stride."); + + // Generate filtered indices for this stride and offset. 
+ static constexpr int new_size = (std::tuple_size_v + Stride - Offset - 1) / Stride; + + template + static constexpr auto to_index(std::index_sequence) + { + return std::index_sequence<(Offset + Is * Stride)...>{}; + } + + using filtered_indices = decltype(to_index(std::make_index_sequence{})); + + // Helper struct to construct the new tuple type from the filtered indices. + template + struct make_filtered_tuple_type_impl; + + template + struct make_filtered_tuple_type_impl> + { + using type = std::tuple...>; + }; + + using type = typename make_filtered_tuple_type_impl::type; +}; + +// Filter a tuple with a stride and offset. +// +// Tuple is a std::tuple or equivalent +// Stride is a positive integer +// Offset is a non-negative integer smaller than Stride +// +// Evaluates to a smaller tuple type from elements of Tuple with the given Stride and Offset. +// +// Can be used to filter a tuple of types for sharded instantiations. +template +using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo::type; + +// Example compile-time test: +// using OriginalTuple = +// std::tuple; +// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t; +// static_assert(std::is_same_v>, +// "Test Case 1 Failed: Every 3rd from 2nd"); + +} // namespace ck::util diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b018737932..a3f2515099 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -688,7 +688,6 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); - void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]] + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances) { add_device_operation_instances( instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in similarity index 71% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in index 4ca1b2b85e..88c84adfe2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include 
"ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwdDefault>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in index e3a12fd5f4..13fb583725 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1P0>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1S1P0>{}); + 
ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp deleted file mode 100644 index f667481fa4..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - 
NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in new file mode 100644 index 0000000000..d8b35bda68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + + 
add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp deleted file mode 100644 index 2ff2c7f51f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in new file mode 100644 index 0000000000..125e16139d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index f8efa5a7c1..1d9d75a104 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -11,8 +11,6 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -32,23 +30,13 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp - 
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp +xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -71,6 +59,99 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) +# Add generated files for sharded instantiations. 
+include(ShardInstantiation) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD 
+ OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND 
GROUPED_CONV3D_FWD diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp deleted file mode 100644 index a94f687ef8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - 
ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in new file mode 100644 index 0000000000..e1a6e6c0c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp deleted file mode 100644 index 0c63345e7f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, 
- GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in new file mode 100644 index 0000000000..6d196ad71f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace 
ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp deleted file mode 100644 index 43241454a5..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in new 
file mode 100644 index 0000000000..4c67e4912c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp deleted file mode 100644 index 
d02d9f6778..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in new file mode 100644 index 0000000000..0fbefa3bbc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp deleted file mode 100644 index 060eebebc1..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in index f3eccc7dc8..c87783eed9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // 
namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp deleted file mode 100644 index 85b088f416..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in index abea0bea81..ca6d571be1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + add_device_operation_instances( + 
instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp new file mode 100644 index 0000000000..da2f3dc1fa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 0>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp new file mode 100644 index 0000000000..5d551833c0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 1>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp new file mode 100644 index 0000000000..715cbf6beb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 2>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp new file mode 100644 index 0000000000..cf2a9f4023 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 3>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp new file mode 100644 index 0000000000..085b2904d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 4>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp new file mode 100644 index 0000000000..18b1e0c6d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 5>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp new file mode 100644 index 0000000000..b95f1d1229 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 6>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp new file mode 100644 index 0000000000..afe3e5d19f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" + +namespace ck::tensor_operation::device::instance { +template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 7>( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in new file mode 100644 index 0000000000..2586bc0f16 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in new file mode 100644 index 0000000000..7405f86a5f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, 
Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in index ba5d9fb1de..24d6b66976 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, 
+ ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in similarity index 57% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in index fac3098341..91a2444241 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in @@ -3,53 +3,60 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void 
add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in similarity index 59% rename from 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in index 5a2c4a0d5b..7571dff883 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + 
ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in index 701b8eb4a4..38ed240fab 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { 
-namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc new file mode 100644 index 0000000000..38ed240fab --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance From bd96ac9742b9e7da08b9e8a26e0b40d10c54e574 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: 
Fri, 13 Jun 2025 19:39:11 +0200 Subject: [PATCH 005/103] [CK_TILE] Multiple-D GEMM example (#2219) * Multiple d, initial commit * Check Ds Layout * Readme and clang format * Update branch & conflicts * Multiple D - fix clang-formatter * Rename elemetwise_op * Fix CI * Code review part1 * Remove printf * Remove unnecessary comment * Add new tests with Col layout * Review part 2 * Added support for Multiple D GEMM * Update comment * Remove maybe_unused * Clang-format * Review part 3 * Add comment to function * Add comment to function: another * Take number of params for a refrence function * Remove additional d param for 0 tensor * Change name of function * Fix CI fails --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_basic.cpp | 10 +- example/ck_tile/03_gemm/gemm_utils.hpp | 7 +- example/ck_tile/03_gemm/run_gemm_example.inc | 101 +++-- example/ck_tile/03_gemm/universal_gemm.cpp | 23 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 16 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 1 + .../run_batched_gemm_example.inc | 68 ++- example/ck_tile/17_grouped_gemm/README.md | 2 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 14 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 17 +- .../run_grouped_gemm_example.inc | 58 ++- .../ck_tile/19_gemm_multi_d/CMakeLists.txt | 1 + example/ck_tile/19_gemm_multi_d/README.md | 35 ++ .../19_gemm_multi_d/gemm_multi_d_fp16.cpp | 296 +++++++++++++ .../19_gemm_multi_d/gemm_multi_d_fp16.hpp | 79 ++++ .../run_gemm_multi_d_fp16_example.inc | 247 +++++++++++ example/ck_tile/19_gemm_multi_d/utils.hpp | 50 +++ example/ck_tile/CMakeLists.txt | 1 + .../ck_tile/core/tensor/tile_elementwise.hpp | 32 ++ .../ck_tile/host/reference/reference_gemm.hpp | 52 +++ .../unary_element_wise_operation.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 101 ++++- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 44 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 385 ++++++++++++----- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 62 +-- 
test/ck_tile/CMakeLists.txt | 1 + .../batched_gemm/test_batched_gemm_util.hpp | 12 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 16 +- test/ck_tile/gemm_multi_d/CMakeLists.txt | 4 + .../gemm_multi_d/test_gemm_multi_d.cpp | 39 ++ .../test_gemm_multi_d_ut_cases.inc | 334 ++++++++++++++ .../gemm_multi_d/test_gemm_multi_d_util.hpp | 407 ++++++++++++++++++ .../grouped_gemm/test_grouped_gemm_util.hpp | 35 +- 34 files changed, 2267 insertions(+), 285 deletions(-) create mode 100644 example/ck_tile/19_gemm_multi_d/CMakeLists.txt create mode 100644 example/ck_tile/19_gemm_multi_d/README.md create mode 100644 example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp create mode 100644 example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp create mode 100644 example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc create mode 100644 example/ck_tile/19_gemm_multi_d/utils.hpp create mode 100644 test/ck_tile/gemm_multi_d/CMakeLists.txt create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc create mode 100644 test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index af8d965b30..368d1e502d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for Multiple D GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. 
diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index de9608bcb4..defeffc2ee 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -14,13 +14,17 @@ template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + bool Persistent, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { if constexpr(Persistent) std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; @@ -53,8 +57,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; const auto Run = [&](const auto memory_operation_) { diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f3d11c751b..6987a2492e 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -252,10 +252,13 @@ auto create_args(int argc, char* argv[]) // host API template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + bool Persistent = false, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index bf455a6415..cc9a825c73 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -146,11 +146,14 @@ void permute_vectors_i4x4_b(Tensor& tensor) template + typename DsLayout, + typename CLayout, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, 
ck_tile::DeviceMem& c_m_n_dev_buf, @@ -165,41 +168,48 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_repeat, bool persistent) { - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + {}, + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; float ave_time; if(persistent) { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } else { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } @@ -328,20 +338,27 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm( - a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat, - persistent); + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index fafe40c333..beb6987605 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -15,13 +15,17 @@ template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + typename DsLayout, + typename ELayout, + bool Persistent, + typename 
CDEElementWise = ck_tile::element_wise::PassThrough> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -30,24 +34,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& sequence, GemmConfig::PermuteA, GemmConfig::PermuteB>; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; + ELayout>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits +template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) @@ -123,12 +132,16 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre tail_number_v>; using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 0999c7ad3b..78d915e873 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 16a31e519a..7d5e1910dd 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -23,7 +23,16 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, 
rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -44,20 +53,29 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::BatchedGemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; args.stride_A = stride_A; args.stride_B = stride_B; - args.stride_C = stride_C; + args.stride_E = stride_C; args.batch_stride_A = batch_stride_A; args.batch_stride_B = batch_stride_B; - args.batch_stride_C = batch_stride_C; + args.batch_stride_E = batch_stride_C; args.batch_count = batch_count; - float ave_time = batched_gemm( + float ave_time = batched_gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::string op_name{"Batched Gemm"}; @@ -169,22 +187,30 @@ int run_batched_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_batched_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_count, - kbatch, - n_warmup, - n_repeat); + invoke_batched_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index d1a0458eda..59396a558b 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -1,6 
+1,6 @@ # Grouped CShuffle GEMM -This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile. +This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. ## build ``` diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 2a72c6325e..85d75320c5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,7 +16,16 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -template +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -130,9 +139,12 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; auto create_args(int argc, char* argv[]) { @@ -82,7 +83,17 @@ inline std::size_t get_workspace_size(const std::vector& gem return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } -template +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index a01d8178cc..5ed1219731 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -30,7 +30,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(int n_warmup, int n_repeat, int group_count, @@ -44,7 +54,16 @@ float 
invoke_gemm(int n_warmup, if constexpr(!Persistent) { // Regular version of grouped gemm - ave_time = grouped_gemm( + ave_time = grouped_gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, gemm_workspace.GetDeviceBuffer()); @@ -64,16 +83,18 @@ float invoke_gemm(int n_warmup, const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.c_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - arg.stride_C, - arg.k_batch}); + kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, + arg.b_ptr, + {}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + {}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, @@ -219,10 +240,19 @@ int run_grouped_gemm_example_with_layouts(int argc, void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); } - invoke_gemm(warmup, repeat, group_count, gemm_descs); + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + Persistent>(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) { diff --git a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..e2e68b325a --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp) diff --git a/example/ck_tile/19_gemm_multi_d/README.md b/example/ck_tile/19_gemm_multi_d/README.md new file mode 100644 index 0000000000..7e8cd87546 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/README.md @@ -0,0 +1,35 @@ +#Multiple D GEMM + 
+This folder contains an example of Multiple D GEMM using the ck_tile tile-programming implementation. + +## build +``` +#in the root of ck_tile +mkdir build && cd build +#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ + leave it blank +sh ../script/cmake-ck-dev.sh ../ +#The basic pipeline method on the gemm calculation +make tile_example_gemm_multi_d_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16` + +## example +``` +args: + -m M dimensions - (Default: 3840) + -n N dimensions - (Default: 4096) + -k K dimensions - (Default: 4096) +-a_layout Tensor A layout (default:R) +-b_layout Tensor B layout (default:C) +-ds_layout Tensor D layout (default:R) +-e_layout Tensor E layout (default:R) +-stride_a Tensor A strides - (Default: 0) +-stride_b Tensor B strides - (Default: 0) +-stride_e Tensor E strides - (Default: 0) +-stride_ds Tensor D strides - (Default: 0) +-validate 0. No validation, 1. Validation on GPU. (Default: 1) + -warmup Number of iterations before benchmarking the kernel. (Default: 10) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. (Default: 1) +``` diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp new file mode 100644 index 0000000000..6c5ca08426 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_d_fp16.hpp" +#include "utils.hpp" + +template +auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float +{ +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr 
ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! 
Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + if(has_hot_loop) + { +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + if(tail_num == ck_tile::TailNumber::One) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + auto check_tail = [&](auto... 
TNs) { + (try_run(tail_num), ...); + }; + + check_tail(ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); + +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + RunSplitk( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else + { + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} + +#include "run_gemm_multi_d_fp16_example.inc" + +int main(int argc, char* argv[]) { return !run_multiple_d_gemm_example(argc, argv); } diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp new file mode 100644 index 0000000000..3ce3965e56 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; +using EDataType = ck_tile::half_t; +using DsDataType = ck_tile::tuple; +using AccDataType = float; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_a", 
"0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") + .insert("v", "1", "0. No validation, 1. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +using gemm_multi_d_kargs = ck_tile::GemmHostArgs; + +template +float gemm_multi_d(const gemm_multi_d_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc new file mode 100644 index 0000000000..a0d7157d03 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include + +template +float invoke_gemm_multi_d(const void* a_m_k_dev_buf, + const void* b_k_n_dev_buf, + const std::array& ds_m_n_dev_buf, + void* e_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t StrideA, + ck_tile::index_t StrideB, + const std::array& StrideDs, + ck_tile::index_t StrideE, + int n_warmup, + int n_repeat, + int k_batch) +{ + gemm_multi_d_kargs gemm_descs({a_m_k_dev_buf, + b_k_n_dev_buf, + ds_m_n_dev_buf, + e_m_n_dev_buf, + k_batch, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE}); + + float ave_time = gemm_multi_d( + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm Multiple-D"}; + static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * M * N; + flop += sizeof(ck_tile::remove_cvref_t>) * M * N; + }); + + num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm Multiple-D kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE + << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; + + return ave_time; +} + +template +int run_multiple_d_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + 
} + using CDElementWiseFn = MultiplyMultiply; + using DsLayout = ck_tile::tuple; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t StrideA = arg_parser.get_int("stride_a"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_b"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); + + ck_tile::index_t StrideD0 = StrideD; + ck_tile::index_t StrideD1 = StrideD; + + const int n_warmup = arg_parser.get_int("warmup"); + const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); + + StrideA = get_default_stride(M, K, StrideA, is_row_major(a_layout)); + StrideB = get_default_stride(K, N, StrideB, is_row_major(b_layout)); + StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout)); + StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout)); + StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout)); + + ck_tile::HostTensor a_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_tensors( + host_tensor_descriptor(K, N, StrideB, is_row_major(b_layout))); + ck_tile::HostTensor d0_m_n_tensors( + host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout))); + ck_tile::HostTensor d1_m_n_tensors( + host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout))); + ck_tile::HostTensor e_m_n_device_result( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tesnor); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data()); + b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data()); + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array stridesDs = {StrideD0, StrideD1}; + + invoke_gemm_multi_d(a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + + ck_tile::HostTensor e_m_n_host_ref( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref); + + bool pass{true}; + if(arg_parser.get_int("v")) + { + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + pass &= ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << std::endl; + std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is: " << (pass ? 
"correct" : "fail") << std::endl; + } + return pass; +} + +int run_multiple_d_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string ds_layout = arg_parser.get_str("ds_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_multiple_d_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} diff --git a/example/ck_tile/19_gemm_multi_d/utils.hpp b/example/ck_tile/19_gemm_multi_d/utils.hpp new file mode 100644 index 0000000000..a201d11ffc --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/utils.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d479cd35f6..f2f39b6e17 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -18,5 +18,6 @@ add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) +add_subdirectory(19_gemm_multi_d) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 79018b9ced..d2b24ad54e 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -59,6 +59,38 @@ 
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, return out_dstr_tensor; } +/** + * @brief Template function that "unpacks" a tuple and applies an element-wise operation. + * + * @param in_element_func Function to apply element-wise. + * @param t Any container containing elements to process, with known size and + * tuple-like semantic. + * @return Calls tile_elementwise_inout with unpacked tuple elements. + */ +template +CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func, + const Tuple& t, + std::index_sequence) +{ + return tile_elementwise_inout(in_element_func, t[number{}]...); +} + +/** + * @brief Template function that "unpacks" a tuple and applies an element-wise operation. + * + * @param in_element_func Function to apply element-wise. + * @param t Any container containing elements to process, with known size and + * tuple-like semantic. + * @return Calls the overloaded function, passing an index sequence. + */ +template +CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func, + const Tuple& t) +{ + static constexpr auto size = Tuple::size(); + return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence{}); +} + template CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value) { diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index fe5077083c..c88deaec01 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -71,6 +71,58 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template >> +CK_TILE_HOST void +reference_gemm_multiple_d(const HostTensor& a_m_k, + const HostTensor& b_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& c_m_n, + const ACCElementOp& acc_element_op = {}) +{ + const 
std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + if constexpr(DsDataType::size() == 0) + { + acc_element_op(v_c, ck_tile::type_convert(v_acc)); + } + else if constexpr(DsDataType::size() == 1) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n))); + } + else if constexpr(DsDataType::size() == 2) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n)), + ck_tile::type_convert(ds_m_n[1](m, n))); + } + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); +} + template CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); } }; #endif + } // namespace element_wise } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 6613ceebb2..68e91520bf 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -11,9 +11,12 @@ namespace ck_tile { template ; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using CLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr index_t kBlockSize = kBlockSize_; static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; @@ -43,6 +49,10 @@ struct CShuffleEpilogueProblem static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum 
MemoryOperation = MemoryOperation_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); }; template @@ -53,10 +63,13 @@ struct CShuffleEpilogue using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using CLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; @@ -69,7 +82,10 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); /** * @brief Get the vector store size for C tensor. 
* @@ -83,22 +99,49 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() { constexpr index_t max_vector_size = 16; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return std::min(static_cast(NPerIteration), static_cast(max_vector_size / sizeof(ODataType))); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return std::min(static_cast(MPerIteration), static_cast(max_vector_size / sizeof(ODataType))); } else { - static_assert(false, "Unsupported CLayout!"); + static_assert(false, "Unsupported ELayout!"); } } + /** + * @brief Get the vector store size for Di tensor. + * + * @return The vector store size for Di tensor. + */ + template + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) + { + constexpr index_t max_vector_size = 16; + using DiDataType = remove_cvref_t>; + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else + { + static_assert(false, "Unsupported DLayout!"); + } + return max_vector_size / sizeof(DiDataType); + } /** * @brief Shuffle tile configuration parameters * @@ -116,7 +159,7 @@ struct CShuffleEpilogue else { constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && (kMPerBlock % num_xdl_shuffles == 0), @@ -147,7 +190,8 @@ struct CShuffleEpilogue }(); static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher) + if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( make_tuple(number{}, number{}), 
make_tuple(number{}, number<1>{})); } // M is contiguous dimension - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( make_tuple(number{}, number{}), @@ -177,7 +221,7 @@ struct CShuffleEpilogue } else { - static_assert(false, "Unsupported CLayout!"); + static_assert(false, "Unsupported ELayout!"); } } @@ -202,9 +246,11 @@ struct CShuffleEpilogue return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); } - template - CK_TILE_DEVICE auto - operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) + template + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* p_smem) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); @@ -230,7 +276,7 @@ struct CShuffleEpilogue sequence>; constexpr index_t num_access = SFC::get_num_of_access(); - static_assert(std::is_same_v, + static_assert(std::is_same_v, "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); using TileEncodingPattern = @@ -242,6 +288,12 @@ struct CShuffleEpilogue Problem::kNumWaveGroups>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; @@ -265,8 +317,17 @@ struct CShuffleEpilogue store_tile(in_lds_window, c_warptile_in_tensor_casted); block_sync_lds(); - const auto c_out_tensor = - load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + + const auto ds_tensor = generate_tuple( + [&](auto idx) { return 
load_tile(d_dram_windows[idx]); }, number{}); + + const auto c_ds_tiles = concat_tuple_of_reference( + tie(c_out_tensor, c_out_tensor), + generate_tie( + [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); if constexpr(MemoryOperation == memory_operation_enum::set) { @@ -279,7 +340,13 @@ struct CShuffleEpilogue if constexpr(iAccess != num_access - 1) { constexpr auto step = SFC::get_forward_step(iAccess); + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); } }); } diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index d495c0d950..09c7d58558 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -9,7 +9,7 @@ namespace ck_tile { -struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs +struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs { CK_TILE_HOST BatchedGemmHostArgs() = default; CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, @@ -26,18 +26,28 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_) - : GemmHostArgs( - a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_), + : GemmHostArgs(a_ptr_, + b_ptr_, + {}, + c_ptr_, + k_batch_, + M_, + N_, + K_, + stride_A_, + stride_B_, + {}, + stride_C_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), - batch_stride_C(batch_stride_C_), + batch_stride_E(batch_stride_C_), batch_count(batch_count_) { } ck_tile::index_t batch_stride_A; ck_tile::index_t batch_stride_B; - ck_tile::index_t batch_stride_C; + ck_tile::index_t batch_stride_E; 
ck_tile::index_t batch_count; }; @@ -46,18 +56,18 @@ struct BatchedGemmKernel : public GemmKernel; - using GemmKernelArgs = typename ck_tile::GemmKernelArgs; + using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>; using ADataType = typename Base::ADataType; using BDataType = typename Base::BDataType; - using CDataType = typename Base::CDataType; + using CDataType = typename Base::EDataType; using TilePartitioner = typename Base::TilePartitioner; using GemmPipeline = typename Base::GemmPipeline; using EpiloguePipeline = typename Base::EpiloguePipeline; using ALayout = typename Base::ALayout; using BLayout = typename Base::BLayout; - using CLayout = typename Base::CLayout; + using CLayout = typename Base::ELayout; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -75,7 +85,7 @@ struct BatchedGemmKernel : public GemmKernel(kargs.b_ptr) + batch_offset_B + splitk_batch_offset.b_k_split_offset; - const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C); - const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C); - CDataType* c_ptr = static_cast(kargs.c_ptr) + batch_offset_C; + const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); + const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); + CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index bfb0d2626b..4cd26c2234 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -16,70 +16,72 @@ namespace ck_tile { -/// @brief The GEMM problem definition. 
-/// -/// @par Overview -/// This structure defines the GEMM problem configuration by stating all required information -/// like M,N,K sizes and respective strides. -struct GemmProblem -{ - CK_TILE_HOST GemmProblem() = default; - CK_TILE_HOST GemmProblem( - index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) - : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) - { - } - - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; -}; - /// @brief The GEMM kernel host arguments. /// /// @par Overview /// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments /// object. It contain all necessary information required to build proper kernel argument /// and launch kernel on GPU. -struct GemmHostArgs : public GemmProblem +/// This structure defines the GEMM problem configuration by stating all required information +/// like M,N,K sizes and respective strides. +/// NumDTensor describes the number of D tensors. 
+template +struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - void* c_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_C_) - : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), - a_ptr(a_ptr_), + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), b_ptr(b_ptr_), - c_ptr(c_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), k_batch(k_batch_) { } const void* a_ptr; const void* b_ptr; - void* c_ptr; + const std::array ds_ptr; + void* e_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + index_t stride_E; index_t k_batch; }; /// @brief The GEMM kernel device arguments. +template struct GemmKernelArgs { /// @brief The A input tensor's pointer to device memory. const void* a_ptr; /// @brief The B input tensor's pointer to device memory. const void* b_ptr; - /// @brief The C output tensor's pointer to device memory. - void* c_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; /// @brief GEMM's M dimension size. index_t M; /// @brief GEMM's N dimension size. @@ -93,8 +95,11 @@ struct GemmKernelArgs /// (in memory) of B tensor. index_t stride_B; /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of C tensor. - index_t stride_C; + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. 
+ index_t stride_E; index_t k_batch; }; @@ -133,16 +138,19 @@ struct GemmKernelArgs /// @tparam EpiloguePipeline_ The type of class providing the final part of matrix /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output C tensor in global memory. +/// the output E tensor in global memory. template struct GemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available @@ -163,11 +171,18 @@ struct GemmKernel using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. 
- using CDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + + static constexpr index_t NumDTensor = DsDataType::size(); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>{}; + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + using KernelArgs = GemmKernelArgs; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -190,7 +205,7 @@ struct GemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using Kernel = GemmKernel; - const auto kernel = kentry; + const auto kernel = kentry; int occupancy; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); @@ -200,18 +215,22 @@ struct GemmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const GemmHostArgs& hostArgs) { - return GemmKernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.c_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_C, - hostArgs.k_batch}; + + return KernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -221,8 +240,7 @@ struct GemmKernel struct SplitKBatchOffset { - __device__ SplitKBatchOffset(const GemmKernelArgs& kargs, - const std::size_t k_id = blockIdx.z) + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); const 
index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); @@ -261,10 +279,10 @@ struct GemmKernel index_t splitted_k; }; - CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) + is_any_of::value) { if(kargs.k_batch != 1) { @@ -360,7 +378,56 @@ struct GemmKernel } } - if constexpr(std::is_same_v) + bool DTesnorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTesnorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + }); + + if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { @@ -400,15 +467,17 @@ struct GemmKernel return false; } } - return true; + return DTesnorIsValid; } template - 
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const GemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { @@ -495,29 +564,54 @@ struct GemmKernel } }(); + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) { return make_naive_tensor_view( - c_ptr, + e_ptr, make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), + make_tuple(kargs.stride_E, 1), number{}, number<1>{}); } else { return make_naive_tensor_view( - c_ptr, + e_ptr, make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), + make_tuple(1, kargs.stride_E), number<1>{}, number<1>{}); } }(); - return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view); } template @@ -559,35 +653,57 @@ struct GemmKernel } }(); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + 
if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I2); - if constexpr(std::is_same_v) + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) { - return pad_tensor_view(c_tensor_view, + return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } else { - return pad_tensor_view(c_tensor_view, + return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } }(); - return make_tuple(a_pad_view, b_pad_view, c_pad_view); + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view); } template CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& c_pad_view = views.at(I2); + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); const auto& a_block_window = [&]() { if constexpr(std::is_same_v) @@ -623,12 +739,32 @@ struct GemmKernel } }(); - auto c_block_window = make_tile_window( - c_pad_view, + const auto ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(a_block_window, 
b_block_window, c_block_window); + return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window); } /** @@ -636,7 +772,8 @@ struct GemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer - * @param c_ptr output C pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. @@ -647,9 +784,10 @@ struct GemmKernel template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, - CDataType* c_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, void* smem_ptr_0, - const GemmKernelArgs& kargs, + const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -657,7 +795,7 @@ struct GemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -668,6 +806,7 @@ struct GemmKernel // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); @@ -675,11 +814,11 @@ struct GemmKernel if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } } @@ -690,7 +829,8 @@ struct GemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer - * @param c_ptr output C pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer * @param smem_ptr_0 The starting pointer of 1st shared memory block. * @param smem_ptr_1 The starting pointer of 2nd shared memory block. 
* @param kargs GEMM kernel arguments @@ -701,10 +841,11 @@ struct GemmKernel */ CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - CDataType* c_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, - const GemmKernelArgs& kargs, + const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -712,7 +853,8 @@ struct GemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -722,20 +864,22 @@ struct GemmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } // Non-persistent kernel entry point template > - CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); @@ -743,12 +887,14 @@ struct GemmKernel const index_t i_n = __builtin_amdgcn_readfirstlane(iN * 
TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); + // options const ADataType* a_ptr = static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + + EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; @@ -758,11 +904,12 @@ struct GemmKernel __shared__ char smem_ptr_1[GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, - c_ptr, + kargs.ds_ptr, + e_ptr, smem_ptr_0, smem_ptr_1, kargs, @@ -775,18 +922,25 @@ struct GemmKernel { if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } } // Persistent kernel entry point template , typename = void> - CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); const auto num_tiles = @@ -809,7 +963,7 @@ struct GemmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; @@ -820,11 +974,12 @@ struct GemmKernel if constexpr(!(EpiloguePipeline::MemoryOperation == 
memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, - c_ptr, + kargs.ds_ptr, + e_ptr, smem_ptr_0, smem_ptr_1, kargs, @@ -838,9 +993,17 @@ struct GemmKernel if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } // Advance to the next work item diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index f57600d7a5..533cabb736 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -18,17 +18,17 @@ namespace ck_tile { struct GemmTransKernelArg { - GemmKernelArgs group_karg; + GemmKernelArgs<> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; - GemmTransKernelArg() = default; - GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg() = delete; + GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(GemmKernelArgs&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} + GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} }; template @@ -39,7 +39,7 @@ struct GroupedGemmKernel : public GemmKernel; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using ELayout = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -65,8 +65,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) - -> std::size_t + 
CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector>& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -95,7 +95,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static constexpr auto + GridSize(const std::vector>& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -106,7 +107,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static auto + MakeKargs(const std::vector>& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; @@ -127,7 +129,7 @@ struct GroupedGemmKernel : public GemmKernel(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].c_ptr), - M, - N, - K, - stride_a, - stride_b, - stride_c, - gemm_descs[i].k_batch}; + auto karg = GemmKernelArgs<>{type_convert(gemm_descs[i].a_ptr), + type_convert(gemm_descs[i].b_ptr), + {}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + stride_a, + stride_b, + {}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -177,7 +181,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -192,7 +196,7 @@ struct GroupedGemmKernel : public GemmKernel(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; @@ -204,7 +208,7 @@ struct GroupedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } } @@ -230,7 +234,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, 
const index_t block_idx_n) @@ -238,13 +242,14 @@ struct GroupedGemmKernel : public GemmKernel( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); const auto& a_block_window = gemm_tile_windows.at(Base::I0); const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( @@ -256,9 +261,10 @@ struct GroupedGemmKernel : public GemmKernel( - c_block_window, c_block_tile, smem_ptr_0); + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 8f9d7ac89b..57afb5cbb5 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(image_to_column) add_subdirectory(gemm) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(gemm_multi_d) add_subdirectory(data_type) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index cffa81d1c5..79bd51d65c 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" template class TestCkTileBatchedGemm : public ::testing::Test @@ -23,6 +24,8 @@ class TestCkTileBatchedGemm : public ::testing::Test 
using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; template void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, @@ -102,9 +105,12 @@ class TestCkTileBatchedGemm : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(args, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index b3146b5f8e..5f2a53645d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -76,12 +76,17 @@ class TestCkTileGemmPipeline : public ::testing::Test using CDataType = std::tuple_element_t<6, Tuple>; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; + + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + static constexpr bool Persistent = ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? 
template - void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; @@ -165,9 +170,12 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; args.stride_A = stride_A; args.stride_B = stride_B; - args.stride_C = stride_C; + args.stride_E = stride_C; invoke_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..1ec77eb87a --- /dev/null +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -0,0 +1,4 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) +endif() diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp new file mode 100644 index 0000000000..a634d825b7 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_d_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>, + + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); + +#include "test_gemm_multi_d_ut_cases.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc new file mode 100644 index 0000000000..22d887fa83 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc @@ -0,0 +1,334 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256) +{ 
+ constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256) +{ + constexpr int M = 
1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + 
+TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, 
kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512) +{ + constexpr int M = 1280; + 
constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp new file mode 100644 index 0000000000..7dd91077b1 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +struct ElementWiseAddAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto 
atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileGemmMultiD : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using BDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using AccDataType = std::tuple_element_t<9, Tuple>; + using EDataType = std::tuple_element_t<10, Tuple>; + using CDElementWiseFn = std::tuple_element_t<11, Tuple>; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + template + void invoke_gemm_multi_d(const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t 
TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) + { + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." 
+ << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + public: + void Run(const int M, + const int N, + const int K, + const int k_batch, + int StrideA = 0, + int StrideB = 0, + int StrideD0 = 0, + int StrideD1 = 0, + int StrideE = 0) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + + ck_tile::HostTensor a_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + ck_tile::HostTensor b_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + ck_tile::HostTensor d0_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + ck_tile::HostTensor d1_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tesnor); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 
1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data()); + b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data()); + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + std::array stridesDs = {StrideD0, StrideD1}; + + ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE}); + + invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); + + std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE + << " StrideD0 =" << StrideD0 << " StrideD1 =" << StrideD1 << std::endl; + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + bool pass = true; + + ck_tile::HostTensor e_m_n_host_ref( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + K, k_batch, max_accumulated_value); + pass = ck_tile::check_err(e_m_n_device_result, + 
e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + EXPECT_TRUE(pass); + } +}; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 382a32a7d9..54f772f89e 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" template class TestCkTileGroupedGemm : public ::testing::Test @@ -23,6 +24,8 @@ class TestCkTileGroupedGemm : public ::testing::Test using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; // Get the persistent value from ck_tile::bool_constant using PersistentType = std::tuple_element_t<7, Tuple>; @@ -48,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; - using grouped_gemm_kargs = ck_tile::GemmHostArgs; + using grouped_gemm_kargs = ck_tile::GemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); @@ -127,9 +130,12 @@ class TestCkTileGroupedGemm : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblemGetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, 
stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; @@ -442,16 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test const bool splitk = gemm_descs[0].k_batch > 1; for(const auto& arg : gemm_descs) { - kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.c_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - arg.stride_C, - arg.k_batch}); + kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, + arg.b_ptr, + {}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + {}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, false, 1}; ck_tile::hip_check_error( From a0f4db8d9cb730d15ea32d3c6ede3feb409d8adf Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 13 Jun 2025 13:34:22 -0700 Subject: [PATCH 006/103] check for if misched-bottomup flag is valid (#2341) --- .../65_gemm_multiply_multiply/CMakeLists.txt | 8 +++++++- .../gpu/gemm_blockscale_wp/CMakeLists.txt | 19 +++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 36f1860e4f..b9748aabda 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -43,7 +43,13 @@ endforeach() set(GEMM_OPTIONS) list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") set(BLOCKSCALE_GEMM_OPTIONS) -list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") +check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) +check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) +if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm 
--slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") +elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup") +endif() check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL) if(HAS_MAX_OCCUPANCY_EXPERIMENTAL) list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index 57cbd725aa..c8740e8d8c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -7,10 +7,17 @@ list(APPEND GEMM_BLOCKSCALE_WP_INSTANCES device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp ) - -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp 
PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") -set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") - +check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) +check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) +if(HAS_MISCHED_BOTTOMUP) + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-bottomup=1") +elseif(HAS_MISCHED_PRERA_DIRECTION) + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS 
";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--slp-threshold=-32;-mllvm;--misched-prera-direction=bottomup") +endif() add_instance_library(device_gemm_blockscale_wp_instance ${GEMM_BLOCKSCALE_WP_INSTANCES}) From 56f654a826b4794402e69675185af0bf3b98401b Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 13 Jun 2025 14:13:07 -0700 Subject: [PATCH 007/103] Limit the threads to builf ck_tile engine, use ninja. (#2342) * limit the threads to builf ck_tile engine, use ninja * disable ck_tile engine until it can be built safely --- Jenkinsfile | 18 +++++++++++++----- script/cmake-ck-dev.sh | 2 +- script/cmake-ck-release.sh | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1cb1a6ca6c..f9d7feb77c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -793,7 +793,7 @@ def process_results(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=false 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true @@ -1185,8 +1185,12 @@ pipeline { agent{ label rocmnode("gfx90a") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make benchmark_gemm -j && \ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm && \ ./bin/benchmark_gemm """ } steps{ @@ -1203,8 +1207,12 @@ pipeline { agent{ label rocmnode("gfx942") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make benchmark_gemm -j && \ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& \ + ninja -j128 benchmark_gemm && \ ./bin/benchmark_gemm """ } steps{ diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 0e57af7aef..4d0836af39 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -16,7 +16,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 95b1bebca5..acb04ac75f 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -16,7 +16,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=OFF \ From 2d8a804152ebaa36775fea393227cb956e6e550e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Sun, 15 Jun 2025 15:22:34 -0700 Subject: [PATCH 008/103] Fix direct lds load for gfx950 and clang20 (#2346) * fix direct lds load for gfx950 and clang20 * Update include/ck/utility/amd_buffer_addressing_builtins.hpp * Fix format --------- Co-authored-by: Aviral Goel Co-authored-by: Andriy Roshchenko --- .../utility/amd_buffer_addressing_builtins.hpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp index 1836e9461d..f642e06050 100644 --- a/include/ck/utility/amd_buffer_addressing_builtins.hpp +++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -402,7 +402,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || 
defined(__gfx950__) || defined(__gfx12__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; @@ -838,10 +838,18 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const bool is_valid, const index_t src_element_space_size) { - // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx942__) + constexpr auto dword_bytes = 4; static_assert(bytes_per_thread == dword_bytes); +#endif const int32x4_t src_resource = make_wave_buffer_resource(global_base_ptr, src_element_space_size); @@ -872,7 +880,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } #endif From fb97f75099bae6778adc8f41e20df184c416f83e Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 16 Jun 2025 13:49:04 +0800 Subject: [PATCH 009/103] hot fix block_gemm fail with pipeline_problem by adding NumWaveGroups inside block gemm problem (#2348) --- include/ck_tile/ops/gemm/block/block_gemm_problem.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index d8f66c81ca..fd5211a59a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ 
b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -12,7 +12,8 @@ template + typename BlockGemmShape_, + index_t NumWaveGroups_ = 1> struct BlockGemmProblem { using ADataType = remove_cvref_t; @@ -20,7 +21,8 @@ struct BlockGemmProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; }; } // namespace ck_tile From b34c234f5144d4ebd16ca04a379c907854d087ff Mon Sep 17 00:00:00 2001 From: ruanjm Date: Mon, 16 Jun 2025 17:17:03 +0800 Subject: [PATCH 010/103] Add support for specifying valid flag when fetching elements for tile_scatter_gather (#2332) * Add support for specifying valid flag when fetching elements for tile_scatter_gather Add constexpr for operator[] of TrueGenerator * Use different path when valid is enabled --- .../core/tensor/tile_scatter_gather.hpp | 167 +++++++++++++++--- 1 file changed, 147 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 351737d4d9..c7811133d6 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -33,6 +33,7 @@ template @@ -42,6 +43,7 @@ struct tile_scatter_gather using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; using PageIdxArray = remove_cvref_t; + using ValidArray = remove_cvref_t; using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; using BottomTensorDesc = typename BottomTensorView::TensorDesc; @@ -152,12 +154,14 @@ struct tile_scatter_gather const WindowLengths& window_lengths, const BottomTensorIndex& window_origin, const TileDstr& tile_distribution, - const PageIdxArray& page_idx) + const PageIdxArray& page_idx, + const ValidArray& valids) : bottom_tensor_view_{bottom_tensor_view}, window_lengths_{window_lengths}, 
window_origin_{window_origin}, tile_dstr_{tile_distribution}, page_idx_{page_idx}, + valids_{valids}, pre_computed_coords_{} { #if 0 // debug @@ -336,12 +340,25 @@ struct tile_scatter_gather constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor - const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, - page_offset, - bool_constant{}); + const vector_t vec_value = [&]() { + if constexpr(std::is_same_v) + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + bool_constant{}); + } + else + { + return get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + bool_constant{}); + } + }(); #if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { @@ -451,9 +468,23 @@ struct tile_scatter_gather constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements_raw( - smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + if constexpr(std::is_same_v) + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + } + else + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + 0, + pre_nop_); + } // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -529,11 +560,24 @@ struct tile_scatter_gather // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor - 
get_bottom_tensor_view().template set_vectorized_elements( - bottom_tensor_thread_coord, - page_offset, - vec_value, - bool_constant{}); + if constexpr(std::is_same_v) + { + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + } + else + { + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + valids_[idx_gather], + vec_value, + bool_constant{}); + } + // printf("coord_offset:%d, scatter_offset:%d \n", // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -570,14 +614,23 @@ struct tile_scatter_gather }); } - CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) - { - page_idx_ = new_idx; + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } - // static_for<0, 2, 1>{}([&](auto k0) { - // printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]); - // }); + CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) + { + if constexpr(std::is_same_v == false) + { + valids_ = new_valids; + } } + + CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx, + const ValidArray& new_valids) + { + update_page_idx(new_idx); + update_valids(new_valids); + } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) { window_origin_ = new_window_origin; @@ -657,6 +710,7 @@ struct tile_scatter_gather TileDstr tile_dstr_; PageIdxArray page_idx_; + ValidArray valids_; // this contains: // per-thread coordinate for window adaptor @@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view, remove_cvref_t, remove_cvref_t, remove_cvref_t, + std::nullptr_t, HsGatherDim, NumCoord>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx}; + tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } template {}); } +template 
+CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + const StaticValidArray_& valids, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + HsGatherDim, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx, valids}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + const StaticValidArray& valids, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx, + valids, + number{}); +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + const StaticValidArray& valids, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx, + valids, + number{}); +} + } // namespace ck_tile From d996bc78befb15ee0405ff78d0ad0da00f8550f3 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 16 Jun 2025 02:17:53 -0700 Subject: [PATCH 011/103] fix the flatmm (#2349) --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 3 +++ include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp | 3 ++- include/ck_tile/ops/gemm.hpp | 2 +- script/run_ck_profiler_gemm_with_csv_shapes.py | 4 ++-- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp 
b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c564d7d1b1..8782d2bb6a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a9ed1519e6..d2e1bde58f 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,6 +447,7 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -454,7 +455,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr); + c_block_window, c_block_tile, d_block_window, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 8db822ebd1..a1d37f0824 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include 
"ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index 1f7ec7585f..54b4b337de 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -278,13 +278,13 @@ def main(): shapes = tuples(filename) all_results = [] - from tqdm import tqdm from functools import partial from os import path profiler_bin = path.join(args["build_dir"], "bin", "ckProfiler") - for s in tqdm(shapes): + total = len(shapes) + for idx, s in enumerate(shapes, 1): run_shape_stdout_lines = run_shape( s, profiler_bin, args["op_name"], args["dtype"], args["layout"] ) From f6c2ff9dcedbc58065ae1fc10a661f00716c6839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 16 Jun 2025 15:36:53 +0200 Subject: [PATCH 012/103] Grouped convolution forward with clamp (#2334) * Grouped convolution forward with clamp * Optimize clamp * unary fixes * test gk bias * Revert "test gk bias" This reverts commit 8e42e29d7b64dfa12d15bb85932ce9dd0f334065. * Revert "Revert "test gk bias"" This reverts commit e73c0550ce840f6013580722fb6426df1bbaf17b. 
* workaround comment --- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 11 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 22 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 5 +- .../element/unary_element_wise_operation.hpp | 179 +++++++++++++ .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 95 ++++--- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 143 +++++++---- .../device_operation_instance_factory.hpp | 1 + ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 1 + .../device_grouped_conv_fwd_xdl_instance.hpp | 1 + ...ped_conv_fwd_xdl_large_tensor_instance.hpp | 1 + ...vice_grouped_conv_fwd_xdl_mem_instance.hpp | 1 + ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 1 + .../gpu/grouped_convolution_forward_clamp.hpp | 140 ++++++++++ .../grouped_convolution_forward_clamp_xdl.inc | 242 ++++++++++++++++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 16 ++ ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 67 +++++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 61 +++++ ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 67 +++++ ..._nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp | 60 +++++ ...mp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 60 +++++ ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 41 +++ ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 63 +++++ ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 63 +++++ ...groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 80 ++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 16 ++ ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 127 +++++++++ ...hwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp | 58 +++++ ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 58 +++++ ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 41 +++ ..._gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp | 61 +++++ ..._gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp | 61 +++++ ...ups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 51 ++++ ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 51 +++- .../profile_grouped_conv_fwd_impl.hpp | 9 +- script/convert_miopen_driver_to_profiler.py | 48 ++++ 
test/CMakeLists.txt | 2 +- .../CMakeLists.txt | 10 + .../test_grouped_convnd_fwd_bias_clamp.cpp | 3 +- .../test_grouped_convnd_fwd_clamp.cpp | 95 +++++++ .../test_grouped_convnd_fwd_gk_bias_clamp.cpp | 93 +++++++ .../CMakeLists.txt | 4 - 41 files changed, 2103 insertions(+), 106 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 test/grouped_convnd_fwd_activation/CMakeLists.txt rename test/{grouped_convnd_fwd_bias_clamp => grouped_convnd_fwd_activation}/test_grouped_convnd_fwd_bias_clamp.cpp (96%) create mode 100644 test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp create mode 100644 
test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp delete mode 100644 test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 27da1d91a3..6d04835b21 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static_assert(NumGroupsToMerge >= 1); - static constexpr bool isMultiA = is_detected::value; - static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiAB = isMultiA || isMultiB; // NGCHW is not supported for multiAB static_assert(!(is_NGCHW_NGKHW() || @@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && !isMultiAB && is_same_v && + !is_same_v; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ - BComputeDataType + BComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t< isMultiA || isMultiB, diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index bebcd72ceb..48424c16b9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiD = DsDataType::Size() > 0; static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + static constexpr bool DoElementwiseBeforeCShuffle = + !isMultiABD && is_same_v && + !is_same_v; + static constexpr index_t NumATensor = GetNumABTensors(); static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeDataType, BComputeDataType + AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; @@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 sizeof(EDataType); } - typename GridwiseGemm::Argument gemm_arg{ - p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1}; + typename GridwiseGemm::Argument gemm_arg{p_a_grid, + p_b_grid, + p_e_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + I0, + I1, + false, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 94a4e0da4c..9988367959 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && is_same_v && + !is_same_v; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ - AComputeDataType + AComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 047ff3bd06..8f829496da 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -730,6 +730,15 @@ struct UnaryAbs { y = ck::type_convert(ck::math::abs(ck::type_convert(x))); }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = ck::type_convert(ck::math::abs(x)); + }; 
}; struct UnarySqrt @@ -744,6 +753,79 @@ struct UnarySqrt }; }; +struct Clamp +{ + Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) + : floor_(floor), ceil_(ceil){}; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + const float& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(double& y, const double& x) const + { + const double& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, const half_t& x) const + { + const float a = type_convert(x); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, const float& x) const + { + const float& a = x; + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(bhalf_t& y, const float& x) const + { + const float& a = x; + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(bhalf_t& y, + const bhalf_t& x) const + { + const float a = type_convert(x); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(int& y, const int& x) const + { + const int8_t& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(int8_t& y, const int8_t& x) const + { + const int8_t& a = x; + y = a > floor_ ? (a < ceil_ ? 
a : ceil_) : floor_; + }; + + const float floor_; + const float ceil_; +}; + struct Relu { template @@ -756,6 +838,9 @@ struct Relu y = x > 0 ? x : 0; } + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { @@ -763,6 +848,13 @@ struct Relu float y_f32 = x_f32 > 0 ? x_f32 : 0; y = type_convert(y_f32); } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + float y_f32 = x > 0 ? x : 0; + y = type_convert(y_f32); + }; }; // Fast GeLU @@ -915,6 +1007,16 @@ struct Sigmoid constexpr T one = type_convert(1); y = one / (one + math::exp(-x)); }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + constexpr float one = 1.f; + y = type_convert(one / (one + math::exp(-x))); + }; }; struct Silu @@ -942,6 +1044,15 @@ struct TanH y = math::tanh(x); }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(math::tanh(x)); + }; }; struct ACos @@ -1201,6 +1312,13 @@ struct Swish y = type_convert(x / (1.f + math::exp(bx))); }; + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + float bx = -beta_ * x; + y = type_convert(x / (1.f + math::exp(bx))); + }; + const float beta_; }; @@ -1219,6 +1337,16 @@ struct SoftRelu constexpr T one = type_convert(1); y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha; } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + constexpr float one = 1.f; + y = type_convert(math::log(one + math::exp(x * alpha_)) / alpha_); + }; const float alpha_; }; @@ -1240,6 
+1368,17 @@ struct Power T shifted_scaled_x = casted_alpha + casted_beta * x; y = math::pow(shifted_scaled_x, casted_gamma); } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + const float shifted_scaled_x = alpha_ + beta_ * x; + y = type_convert(math::pow(shifted_scaled_x, gamma_)); + }; + const float alpha_; const float beta_; const float gamma_; @@ -1260,6 +1399,16 @@ struct ClippedRelu T casted_beta = type_convert(beta_); y = math::min(casted_beta, math::max(casted_alpha, x)); } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(math::min(beta_, math::max(alpha_, x))); + }; + const float alpha_; const float beta_; }; @@ -1278,6 +1427,16 @@ struct LeakyRelu T casted_alpha = type_convert(alpha_); y = x >= 0 ? x : x * casted_alpha; } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x >= 0 ? x : x * alpha_); + }; + const float alpha_; }; @@ -1295,6 +1454,16 @@ struct Elu T casted_alpha = type_convert(alpha_); y = x > 0 ? x : casted_alpha * math::expm1(x); } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x > 0 ? 
x : alpha_ * math::expm1(x)); + }; + const float alpha_; }; @@ -1313,6 +1482,16 @@ struct Logistic constexpr T one = type_convert(1); y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + constexpr float one = 1.f; + y = type_convert(alpha_ / (one + ck::math::exp(-x) * alpha_)); + }; const float alpha_; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index be0fff087e..acbccf1889 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -71,11 +71,13 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); + static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0); using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; @@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 
1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( @@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - cde_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR before shuffle constexpr auto sfc_c_vgpr = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 338674ae85..6270d0c4dc 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -186,6 +186,8 @@ __global__ void /// in global memory. Currently not supported! /// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout /// in global memory (pre-shuffled). +/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or +/// after the CShuffle. template + bool PermuteB = false, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemm_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -636,7 +639,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 index_t StrideA_, index_t StrideB_, index_t StrideC_, - index_t KBatch_) + index_t KBatch_, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) : M{M_}, N{N_}, K{K_}, @@ -651,7 +657,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)} + NBlock{CalculateNBlock(N_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} { } @@ -689,6 +698,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 index_t BK0; index_t MBlock; index_t NBlock; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; }; // Argument @@ -704,8 +716,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 index_t StrideB_, index_t StrideC_, index_t k_batch_, - bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + bool is_reduce_ = false, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + a_element_op, +
b_element_op, + c_element_op}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_c_grid{p_c_grid_}, @@ -1377,10 +1401,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1440,7 +1460,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, + problem.a_element_op_, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -1471,7 +1491,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, + problem.b_element_op_, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -1598,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, 
- m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, + ThisThreadBlock, // ThreadGroup + conditional_t, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, @@ -1654,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = @@ -1773,10 +1818,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1836,7 +1877,7 
@@ struct GridwiseGemm_xdl_cshuffle_v3 BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, + problem.a_element_op_, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -1867,7 +1908,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, + problem.b_element_op_, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -2059,7 +2100,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + problem.c_element_op_}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 274273d576..022afe7fa4 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -121,6 +121,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; using AddRelu = ck::tensor_operation::element_wise::AddRelu; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; using AddSilu = ck::tensor_operation::element_wise::AddSilu; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using FastGelu = ck::tensor_operation::element_wise::FastGelu; diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index 3fbf2fbc7b..fca236d03e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 7311f4bf75..d6b695360b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index 5a4d0338b0..3e98852d58 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -26,6 +26,7 @@ using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 6da3ee1a4f..4e6b9c3d1d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -34,6 +34,7 @@ using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 
d074988a22..7ef78d46e2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -26,6 +26,7 @@ using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp new file mode 100644 index 0000000000..cb84ca6130 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_clamp_xdl.inc" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL + // layout NHWGC/GKYXC/NHWGK + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + 
is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_XDL + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc new file mode 100644 index 0000000000..b943bf728f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void 
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt new file mode 100644 index 0000000000..15d236525b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +add_instance_library(device_grouped_conv2d_fwd_clamp_instance + 
xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp + + xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..d770bdc24e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..ade9b466ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..5abab15254 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..61c84fcb29 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..f766db04c9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..45a84fd814 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..42c82c3c1a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..52fc9ed765 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..1156375655 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); + } + else + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt new file mode 100644 index 0000000000..5eb0dd50eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD + 
xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +) + +add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..5293fa70c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + 
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..a454671a52 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..9bc9c1c786 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..f35d6b3307 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..c706ae4d7a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..d6c4bcc417 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..d0f2a16c8a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 3ef9f4505d..c12fa75e34 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -25,6 +25,28 @@ namespace ck { namespace profiler { +// NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to +// just keep such implementation valid. +// TODO: Add possibility to pass GK layout and GK lengths for bias and reuse +// the same instances.
+ +template +auto get_bias_desc(ck::index_t G, ck::index_t K) +{ + if constexpr(NDimSpatial == 1) + { + return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0}); + } + else if constexpr(NDimSpatial == 2) + { + return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0}); + } + else + { + return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0}); + } +} + template + typename IndexType = ck::index_t, + bool BiasGK = false> bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, int init_method, bool do_log, @@ -61,12 +84,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + const index_t G = conv_param.G_; + const index_t K = conv_param.K_; + std::array a_g_n_c_wis_lengths{}; std::array a_g_n_c_wis_strides{}; std::array b_g_k_c_xs_lengths{}; std::array b_g_k_c_xs_strides{}; std::array e_g_n_k_wos_lengths{}; std::array e_g_n_k_wos_strides{}; + std::array d_g_n_k_wos_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; std::array input_left_pads{}; @@ -80,6 +107,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.input_left_pads_, input_left_pads); @@ -89,7 +117,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, Tensor weight(wei_g_k_c_xs_desc); Tensor host_output(out_g_n_k_wos_desc); Tensor device_output(out_g_n_k_wos_desc); - Tensor bias(out_g_n_k_wos_desc); + const auto bias_desc = BiasGK ? 
get_bias_desc(G, K) : out_g_n_k_wos_desc; + Tensor bias(bias_desc); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl; @@ -113,7 +142,11 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); - DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize()); + + const std::size_t bias_dev_buf_size = + BiasGK ? sizeof(OutDataType) * G * K + : sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize(); + DeviceMem bias_device_buf(bias_dev_buf_size); in_device_buf.ToDevice(input.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); @@ -244,6 +277,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + if constexpr(BiasGK) + { + constexpr ck::index_t spatial_offset = 3; + d_g_n_k_wos_strides[1] = 0; + for(int i = 0; i < NDimSpatial; i++) + { + d_g_n_k_wos_strides[i + spatial_offset] = 0; + } + } + for(auto& op_ptr : op_ptrs) { auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), @@ -255,7 +298,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, {e_g_n_k_wos_lengths}, - {e_g_n_k_wos_strides}, + {d_g_n_k_wos_strides}, e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_filter_strides, diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 08e707b665..a1f9ee1528 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -12,6 +12,7 @@ #include 
"ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -34,20 +35,20 @@ template + typename IndexType = ck::index_t, + typename OutElementOp = ck::tensor_operation::element_wise::PassThrough> bool profile_grouped_conv_fwd_impl(int do_verification, int init_method, bool do_log, bool time_kernel, - const ck::utils::conv::ConvParam& conv_param) + const ck::utils::conv::ConvParam& conv_param, + const OutElementOp out_element_op = OutElementOp{}) { using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::PassThrough; const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 2ddcbb67cd..9e2f436e68 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -208,6 +208,8 @@ if __name__ == "__main__": parser.add_argument( "-in_layout", "-I", + "--in_layout", + "--I", default="NCHW", type=str, required=False, @@ -216,6 +218,8 @@ if __name__ == "__main__": parser.add_argument( "-forw", "-F", + "--forw", + "--F", default=0, type=int, required=False, @@ -231,6 +235,8 @@ if __name__ == "__main__": parser.add_argument( "-spatial_dim", "-_", + "--spatial_dim", + "--_", default=2, type=int, required=False, @@ -239,6 +245,8 @@ if __name__ == "__main__": parser.add_argument( "-batchsize", "-n", + "--batchsize", + "--n", default=100, 
type=int, required=False, @@ -247,6 +255,8 @@ if __name__ == "__main__": parser.add_argument( "-in_channels", "-c", + "--in_channels", + "--c", default=3, type=int, required=False, @@ -255,6 +265,8 @@ if __name__ == "__main__": parser.add_argument( "-in_d", "-!", + "--in_d", + "--!", default=32, type=int, required=False, @@ -263,6 +275,8 @@ if __name__ == "__main__": parser.add_argument( "-in_h", "-H", + "--in_h", + "--H", default=32, type=int, required=False, @@ -271,6 +285,8 @@ if __name__ == "__main__": parser.add_argument( "-in_w", "-W", + "--in_w", + "--W", default=32, type=int, required=False, @@ -279,6 +295,8 @@ if __name__ == "__main__": parser.add_argument( "-out_channels", "-k", + "--out_channels", + "--k", default=32, type=int, required=False, @@ -287,6 +305,8 @@ if __name__ == "__main__": parser.add_argument( "-fil_d", "-@", + "--fil_d", + "--@", default=3, type=int, required=False, @@ -295,6 +315,8 @@ if __name__ == "__main__": parser.add_argument( "-fil_h", "-y", + "--fil_h", + "--y", default=3, type=int, required=False, @@ -303,6 +325,8 @@ if __name__ == "__main__": parser.add_argument( "-fil_w", "-x", + "--fil_w", + "--x", default=3, type=int, required=False, @@ -311,6 +335,8 @@ if __name__ == "__main__": parser.add_argument( "-conv_stride_d", "-#", + "--conv_stride_d", + "--#", default=1, type=int, required=False, @@ -319,6 +345,8 @@ if __name__ == "__main__": parser.add_argument( "-conv_stride_h", "-u", + "--conv_stride_h", + "--u", default=1, type=int, required=False, @@ -327,6 +355,8 @@ if __name__ == "__main__": parser.add_argument( "-conv_stride_w", "-v", + "--conv_stride_w", + "--v", default=1, type=int, required=False, @@ -335,6 +365,8 @@ if __name__ == "__main__": parser.add_argument( "-pad_d", "-$", + "--pad_d", + "--$", default=1, type=int, required=False, @@ -343,6 +375,8 @@ if __name__ == "__main__": parser.add_argument( "-pad_h", "-p", + "--pad_h", + "--p", default=1, type=int, required=False, @@ -351,6 +385,8 @@ if __name__ == 
"__main__": parser.add_argument( "-pad_w", "-q", + "--pad_w", + "--q", default=1, type=int, required=False, @@ -359,6 +395,8 @@ if __name__ == "__main__": parser.add_argument( "-verify", "-V", + "--verify", + "--V", default=1, type=int, required=False, @@ -367,6 +405,8 @@ if __name__ == "__main__": parser.add_argument( "-time", "-t", + "--time", + "--t", default=0, type=int, required=False, @@ -375,6 +415,8 @@ if __name__ == "__main__": parser.add_argument( "-dilation_d", "-^", + "--dilation_d", + "--^", default=1, type=int, required=False, @@ -383,6 +425,8 @@ if __name__ == "__main__": parser.add_argument( "-dilation_h", "-l", + "--dilation_h", + "--l", default=1, type=int, required=False, @@ -391,6 +435,8 @@ if __name__ == "__main__": parser.add_argument( "-dilation_w", "-j", + "--dilation_w", + "--j", default=1, type=int, required=False, @@ -399,6 +445,8 @@ if __name__ == "__main__": parser.add_argument( "-group_count", "-g", + "--group_count", + "--g", type=int, default=1, required=False, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1f2e7022ba..5b25550d9b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -252,7 +252,7 @@ add_subdirectory(reduce) add_subdirectory(convnd_fwd) add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_convnd_fwd) -add_subdirectory(grouped_convnd_fwd_bias_clamp) +add_subdirectory(grouped_convnd_fwd_activation) add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt new file mode 100644 index 0000000000..8bded647b6 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -0,0 +1,10 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility 
device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_gk_bias_clamp test_grouped_convnd_fwd_gk_bias_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_gk_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance) +endif() diff --git a/test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp similarity index 96% rename from test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp rename to test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index 7d5437d247..f3a569115a 100644 --- a/test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -41,7 +41,8 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, DataType, DataType, - IndexType>( + IndexType, + false /*BiasGK*/>( true, // do_verification 1, // init_method: integer value false, // do_log diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp new file mode 100644 index 0000000000..d3ede8671e --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using Clamp = ck::tensor_operation::element_wise::Clamp; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + Clamp out_element_op{0.f, 256.f}; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + out_element_op); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types>; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwd2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwd2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 
1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp new file mode 100644 index 0000000000..0a41eac286 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AddClamp = ck::tensor_operation::element_wise::AddClamp; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types>; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwd2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwd2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 
32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt b/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt deleted file mode 100644 index 4630a37d33..0000000000 --- a/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp) - target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) -endif() From 5523df4b2dfab16d6144d7717b3b075f8c6d5104 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 16 Jun 2025 07:54:55 -0700 Subject: [PATCH 013/103] Revert "fix the flatmm (#2349)" (#2352) This reverts commit d996bc78befb15ee0405ff78d0ad0da00f8550f3. 
--- example/ck_tile/18_flatmm/flatmm_basic.cpp | 3 --- include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp | 3 +-- include/ck_tile/ops/gemm.hpp | 2 +- script/run_ck_profiler_gemm_with_csv_shapes.py | 4 ++-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 8782d2bb6a..c564d7d1b1 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,12 +49,9 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, - ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index d2e1bde58f..a9ed1519e6 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,7 +447,6 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -455,7 +454,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr); + c_block_window, c_block_tile, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index a1d37f0824..8db822ebd1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index 54b4b337de..1f7ec7585f 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -278,13 +278,13 @@ def main(): shapes = tuples(filename) all_results = [] + from tqdm import tqdm from functools import partial from os import path profiler_bin = path.join(args["build_dir"], "bin", "ckProfiler") - total = len(shapes) - for idx, s in enumerate(shapes, 1): + for 
s in tqdm(shapes): run_shape_stdout_lines = run_shape( s, profiler_bin, args["op_name"], args["dtype"], args["layout"] ) From 6589f50bc93ee3c4ccb7c8a6c765338284b9bc73 Mon Sep 17 00:00:00 2001 From: rahjain-amd Date: Mon, 16 Jun 2025 21:59:35 +0530 Subject: [PATCH 014/103] Add cmake flag to enable Assembly dump (#2347) This flag makes it easy to dump assembly for the example kernels. --- CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index aab74f3069..b0fc725236 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,6 +308,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) +option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -321,6 +322,12 @@ if(USE_OPT_GFX11) message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() +if(ENABLE_ASM_DUMP) + add_compile_options(--save-temps) + add_compile_options(-Wno-gnu-line-marker) + message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) From 3c4cdfac4f6dd9c2f952a02acb028e2c3dd62ef9 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 16 Jun 2025 17:38:52 -0700 Subject: [PATCH 015/103] Fix the CK Tile related operators (#2356) * fix the flatmm * Fix the pipeline * address the comment --- example/ck_tile/03_gemm/gemm_basic.cpp | 3 +++ example/ck_tile/03_gemm/universal_gemm.cpp | 2 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 3 +++ include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp | 3 ++- include/ck_tile/ops/gemm.hpp | 2 +- .../ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp | 1 + .../gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 ++ 
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp | 2 ++ include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp | 1 + script/run_ck_profiler_gemm_with_csv_shapes.py | 8 ++++++-- 10 files changed, 22 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index defeffc2ee..1906b0bda7 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index beb6987605..3ec90e7f00 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -166,7 +166,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: // clear c mem if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( - args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; ave_time = ck_tile::launch_kernel_preprocess( s, diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c564d7d1b1..8782d2bb6a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git 
a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a9ed1519e6..d2e1bde58f 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,6 +447,7 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -454,7 +455,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr); + c_block_window, c_block_tile, d_block_window, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 8db822ebd1..a1d37f0824 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 
9ef7f3f0ef..55220730cd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -1,5 +1,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 217408fffa..881467cb94 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t kLdsAlignmentInBytes = 16; [[nodiscard]] CK_TILE_HOST static const std::string GetName() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 678fb6eb46..b349991470 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -32,6 +32,8 @@ struct GemmPipelineProblemBase static constexpr bool TransposeC = Traits::TransposeC; + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp 
b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 353192d86f..c6f83068a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -28,6 +28,7 @@ struct TileGemmTraits static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; }; template Date: Tue, 17 Jun 2025 10:07:08 -0400 Subject: [PATCH 016/103] add script to pre commit hooks for checking file permissions (#2322) --- .pre-commit-config.yaml | 6 ++++++ script/remove_exec_bit.sh | 8 ++++++++ 2 files changed, 14 insertions(+) create mode 100755 script/remove_exec_bit.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6700ae05b..4dc70c1ffd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,9 @@ repos: verbose: false language: script types: [c++] + - id: remove-exec-bit + name: Remove executable bit from non-executable files + entry: script/remove_exec_bit.sh + language: script + types_or: [c++, text] + verbose: true diff --git a/script/remove_exec_bit.sh b/script/remove_exec_bit.sh new file mode 100755 index 0000000000..25466d8c37 --- /dev/null +++ b/script/remove_exec_bit.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +for file in $(git diff --cached --name-only --diff-filter=ACM | grep -E '\.(cpp|hpp|txt|inc)$'); do + if [ -x "$file" ]; then + chmod -x "$file" + echo "[remove-exec-bit] Removed executable bit from $file" >&2 + fi +done From 4c57157d508e4c102626730aa372c8111670a878 Mon Sep 17 00:00:00 2001 From: Satyanvesh Dittakavi <53337087+satyanveshd@users.noreply.github.com> Date: Wed, 18 Jun 2025 00:24:30 +0530 Subject: [PATCH 017/103] Do not use warpSize as compile time constant as it is removed (#2320) * Do not use warpSize as compile time constant as it is removed * Update tile_image_to_column_shape.hpp update warpSize usage. 
* clean-up all use of warpSize, make sure code builds * fix --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin Co-authored-by: Bartlomiej Kocot --- example/ck_tile/02_layernorm2d/generate.py | 20 ++-- example/ck_tile/05_reduce/reduce.hpp | 2 +- example/ck_tile/10_rmsnorm2d/generate.py | 20 ++-- .../add_rmsnorm2d_rdquant_fwd.hpp | 20 ++-- .../ck_tile/12_smoothquant/smoothquant.hpp | 20 ++-- .../14_moe_smoothquant/moe_smoothquant.hpp | 20 ++-- include/ck/ck.hpp | 6 + ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 2 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 2 +- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 4 +- ...kwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 2 +- ...ckwise_gemm_pipeline_xdlops_v2_b_scale.hpp | 4 +- .../gridwise_multiblock_batchnorm_forward.hpp | 2 +- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 6 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 6 +- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 6 +- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 4 +- .../gpu/grid/gridwise_moe_gemm.hpp | 11 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 10 +- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 10 +- .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 4 +- .../ck/utility/workgroup_synchronization.hpp | 2 +- include/ck_tile/core/arch/utility.hpp | 2 +- .../flatmm_32x512x128_1x4x1_16x16x32.hpp | 26 ++--- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 38 +++---- .../fused_moe/kernel/fused_moegemm_shape.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 106 +++++++++--------- .../fused_moegemm_pipeline_flatmm_policy.hpp | 52 ++++----- .../pipeline/tile_image_to_column_shape.hpp | 2 +- .../norm_reduce/block/block_norm_reduce.hpp | 4 +- .../ops/reduce/block/block_reduce2d.hpp | 4 +- 31 files changed, 213 insertions(+), 206 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 0238a125dc..2dc9ccbd77 100644 --- 
a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index 55e479591c..50ffb9c1c7 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -35,7 +35,7 @@ struct Reduce2dShape static constexpr index_t Repeat_N = 
Block_N / (WarpPerBlock_N * Warp_N); static constexpr index_t BlockSize = - warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; template ; using UnquantYDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index c91b387d62..1d843b5594 100644 --- 
a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -80,22 +80,22 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -103,13 +103,13 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp index 83ad7b012c..265399c276 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ 
b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -49,22 +49,22 @@ struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -72,13 +72,13 @@ struct smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index c1b90b14b2..b29295f175 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -38,22 +38,22 @@ struct moe_smoothquant_traits_ using InputType = ck_tile::remove_cvref_t; using OutputType = 
ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); } }(); @@ -61,13 +61,13 @@ struct moe_smoothquant_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(WarpSize % ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; } }(); diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 26e4787949..3c1373a387 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -274,6 +274,12 @@ namespace ck { +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +__device__ static constexpr int WarpSize = 64; +#else +__device__ static constexpr int WarpSize = 32; +#endif + enum struct InMemoryDataOperationEnum { Set, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index f366f309ff..5370cfa975 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -45,7 +45,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base using ThisThreadBlock = ThisThreadBlock; - // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. static constexpr index_t WaveSize = 64; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 94772361d3..9296b8136f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -40,7 +40,7 @@ struct BlockwiseGemmXdlops_pipeline_base using ThisThreadBlock = ThisThreadBlock; - // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. static constexpr index_t WaveSize = 64; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 54edf0c353..a6b5e272ff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 
4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index c8ad9c5b02..0c030030fe 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 776f66dbbb..69002d7962 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -141,7 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 
4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); @@ -632,7 +632,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * warpSize / BlockSize : 1; + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( 32768 / WgpPerCU, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp index 47573107cf..7c9febf4de 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -202,7 +202,7 @@ struct GridwiseMultiblockBatchNormForward const index_t block_local_id = block_global_id % blkgroup_size; if(block_local_id == 0) - gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]); + gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]); const auto thread_cluster_idx = thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index cfa8bfeb2a..8d5c844103 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -347,7 +347,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr 
index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1229,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1607,7 +1607,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 3eb0f986b3..d31ed19787 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -374,7 +374,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1249,7 +1249,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPackPerGroup * (get_thread_local_1d_id() % warpSize))); + 
KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1687,7 +1687,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPackPerGroup * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 322cd3d162..909376e5f7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -370,7 +370,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1208,7 +1208,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1707,7 +1707,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup 
* (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 223670e3bc..6691c63484 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -422,7 +422,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor_packed( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } @@ -1886,7 +1886,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle get_warp_local_1d_id() % NWave, 0, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment auto a_block_buf_ping = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 62d94c0bf8..92aab5af52 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -405,7 +405,7 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1315,7 +1315,7 @@ struct GridwiseMoeGemm 
make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1361,7 +1361,8 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -2027,7 +2028,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2077,7 +2078,7 @@ struct GridwiseMoeGemm make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); blockwise_gemm_pipeline.template Run( a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index fbfe2509ff..f092c9c1eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -410,7 +410,7 @@ struct GridwiseMoeGemmBlockScale __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, 
I1)); @@ -1355,7 +1355,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1467,7 +1467,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( @@ -2105,7 +2105,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmBlockScale make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index fc156a878f..59693a5861 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -409,7 +409,7 @@ struct GridwiseMoeGemmMX __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = 
Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber), make_tuple(NWave * NXdlPack * K0 * NkSwizzleNumber, @@ -1415,7 +1415,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1508,7 +1508,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, @@ -2123,7 +2123,7 @@ struct GridwiseMoeGemmMX get_warp_local_1d_id() % NWave, 0, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2221,7 +2221,7 @@ struct GridwiseMoeGemmMX make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 7238917920..9ccd334262 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -2319,7 +2319,7 @@ struct 
GridwiseMoeGemmMXBNS get_warp_local_1d_id() % NWave, 0, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -2417,7 +2417,7 @@ struct GridwiseMoeGemmMXBNS make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp index 24858fdbdc..af5b0808fb 100644 --- a/include/ck/utility/workgroup_synchronization.hpp +++ b/include/ck/utility/workgroup_synchronization.hpp @@ -32,7 +32,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits) // all the workgroups in the synchronization group is supposed to call this function static __device__ void gms_barrier(int* p_control_bits) { - constexpr int mask = warpSize - 1; + constexpr int mask = WarpSize - 1; if((threadIdx.x & mask) == 0) { diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index df0f54c5ed..7184f99521 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta) #elif 1 static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); - const uint32_t wrap_around_lane_delta = warpSize - lane_delta; + const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta; const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast(v_local)); diff --git 
a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 869ab32c2e..1dcd62011a 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 // constexpr index_t Block_M = Problem::BlockShape::Block_M0; // constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS @@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK = LanesPerK / WarpSize; if constexpr(wavesPerK > NumWarps) { // TODO: need multiple issues along K to load all data @@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds store vector(actually no explicit store) @@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 make_tuple( make_pass_through_transform(number{}), 
make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); @@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds store vector(actually no explicit store) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 26f7e46f9f..30d07a4754 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; - static_assert(warpSize * KVector >= kKPerBlock && - warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && + WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; - constexpr index_t LaneGroups = 
warpSize / LanesPerK; + constexpr index_t LaneGroups = WarpSize / LanesPerK; constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - return NumIssues * NumWarps * (warpSize * KVector + kPad); + return NumIssues * NumWarps * (WarpSize * KVector + kPad); } }(); @@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // how many lane (within a wave) to load K constexpr index_t LaneGroups = - warpSize / + WarpSize / LanesPerK; // how many groups (within a wave), they may load different N, but same K constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); @@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, // n2 number{}, // k0 number{}), // k1 - make_tuple(number{}, + make_tuple(number{}, number{}, - number{}, + number{}, number{}, number<1>{}), number()>{}, @@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + 
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad); // constexpr index_t SingleVSize = // MakeVLdsBlockDescriptor().get_element_space_size(); constexpr index_t BufferSize = @@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, // k0 number{}), // k1 make_tuple(number{}, - number{}, - number{}, + number{}, + number{}, number{}, number{}, number<1>{}), @@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for global load - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp index 4f3f8bb7d3..336bdc806f 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -101,7 +101,7 @@ struct FusedMoeGemmShape static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; - static constexpr index_t BlockSize = warpSize * NumWarps; + static constexpr index_t BlockSize = 
WarpSize * NumWarps; // some assert static_assert(Block_M0 == Block_M1); diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 4166c1c602..d3c98d7bca 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -381,7 +381,7 @@ struct MoeSortingKernel } // reduce single pixel within a wave - template + template __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -618,7 +618,7 @@ struct MoeSortingKernel { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 - static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; + static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile; { index_t eid = tid / experts_per_wave; index_t expert_offset = cumsum[eid] + @@ -686,7 +686,7 @@ struct MoeSortingKernel void* smem) const { const index_t tid = static_cast(threadIdx.x); - const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize); const index_t lid = __lane_id(); constexpr index_t block_size = 256; // blockDim.x; const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; @@ -791,7 +791,7 @@ struct MoeSortingKernel // NOTE: under this block can never use __syncthreads! int i_e_ = 0; int local_cumsum_ = 0; - for(; i_e_ < num_experts; i_e_ += warpSize) + for(; i_e_ < num_experts; i_e_ += WarpSize) { int pre_cumsum_ = smem_cumsum(lid == 0 ? 
i_e_ : 0); int local_cnt = smem_cumsum(i_e_ + lid + 1); @@ -836,7 +836,7 @@ struct MoeSortingKernel // cumsum padded in case local cumsum is zero, but // pre_sumsum has value, which will result int // zero local cumsum(but we want at least padded) - wave_cumsum(local_cumsum_); + wave_cumsum(local_cumsum_); if((i_e_ + lid) < num_experts) smem_cumsum(i_e_ + lid + 1) = local_cumsum_; @@ -844,7 +844,7 @@ struct MoeSortingKernel if constexpr(Problem::LocalExpertMasking) { local_masking += pre_cumsum_masking; - wave_cumsum(local_masking); + wave_cumsum(local_masking); if((i_e_ + lid) < num_experts) smem_cumdup(i_e_ + lid + 1) = local_masking; } @@ -854,7 +854,7 @@ struct MoeSortingKernel // than 0(which is not we want) __builtin_amdgcn_s_waitcnt(0xc07f); } - if((lid + i_e_ - warpSize) == (num_experts - 1)) + if((lid + i_e_ - WarpSize) == (num_experts - 1)) { *p_total_tokens_post_pad = local_cumsum_; } @@ -1091,7 +1091,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() return chunk * sizeof(index_t); }; -template +template CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number = {}) { // constexpr int wave_size = 64; @@ -1456,7 +1456,7 @@ struct MoeSortingMultiPhaseKernel_P1 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return BLOCK_SIZE / warpSize * sizeof(IndexType); + return BLOCK_SIZE / WarpSize * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1498,8 +1498,8 @@ struct MoeSortingMultiPhaseKernel_P1 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % warpSize; - index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / WarpSize; // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1512,7 +1512,7 @@ struct MoeSortingMultiPhaseKernel_P1 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / WarpSize); 
i++) { c += s[i]; } @@ -1601,7 +1601,7 @@ struct MoeSortingMultiPhaseKernel_P01 // in byte CK_TILE_HOST static constexpr auto GetSmemSize() { - return BLOCK_SIZE / warpSize * sizeof(IndexType); + return BLOCK_SIZE / WarpSize * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1685,8 +1685,8 @@ struct MoeSortingMultiPhaseKernel_P01 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); } - index_t lane_id = threadIdx.x % warpSize; - index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % WarpSize; + index_t wave_id = threadIdx.x / WarpSize; // reduce cross wave IndexType* s = reinterpret_cast(smem); @@ -1700,7 +1700,7 @@ struct MoeSortingMultiPhaseKernel_P01 if(threadIdx.x == 0) { index_t c = 0; - for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++) { c += s[i]; } @@ -1777,7 +1777,7 @@ struct MoeSortingMultiPhaseKernel_P2 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { // return 2 * BLOCK_SIZE * sizeof(IndexType); - return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1802,8 +1802,8 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / warpSize; - index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % WarpSize; IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -1848,22 +1848,22 @@ struct MoeSortingMultiPhaseKernel_P2 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == 
warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -1978,7 +1978,7 @@ struct MoeSortingMultiPhaseKernel_P3 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType); + return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -1995,8 +1995,8 @@ struct MoeSortingMultiPhaseKernel_P3 WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / warpSize; - int lane_id = threadIdx.x % warpSize; + int wave_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; int e_start = p_expert_cumsum[eid]; int e_end = p_expert_cumsum[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2026,17 +2026,17 @@ struct MoeSortingMultiPhaseKernel_P3 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; @@ -2081,7 +2081,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) { constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 const index_t expert_cumsum_elem = num_experts_ + 1; - return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int); + return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int); } } // namespace impl @@ -2186,15 +2186,15 @@ struct MoeSortingMultiPhaseKernel_P23 const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; IndexType* p_total_tokens_post_pad = reinterpret_cast(kargs.p_total_tokens_post_pad); IndexType* p_sorted_expert_ids = reinterpret_cast(kargs.p_sorted_expert_ids); const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; - index_t wave_id = threadIdx.x / warpSize; - index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / WarpSize; + index_t lane_id = threadIdx.x % WarpSize; IndexType prev_cumsum_a = 0; IndexType prev_cumsum_b = 0; @@ -2239,22 +2239,22 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType cumsum_b = b_; // Note: we first cumsum local round, then add previous cumsum - impl::moe_sorting_wave_cumsum(cumsum_a); - impl::moe_sorting_wave_cumsum(cumsum_b); + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum_a; - s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev_a = s[4 + i_w]; - IndexType prev_b = 
s[4 + i_w + BLOCK_SIZE / warpSize]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize]; prev_a = wave_id > i_w ? prev_a : 0; // mask out prev_b = wave_id > i_w ? prev_b : 0; // mask out cumsum_a += prev_a; @@ -2324,13 +2324,13 @@ struct MoeSortingMultiPhaseKernel_P23 IndexType* s = reinterpret_cast(smem); MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); - IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize; const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); int eid = blockIdx.x; - int wave_id = threadIdx.x / warpSize; - int lane_id = threadIdx.x % warpSize; + int wave_id = threadIdx.x / WarpSize; + int lane_id = threadIdx.x % WarpSize; int e_start = p_expert_cumsum_smem[eid]; int e_end = p_expert_cumsum_smem[eid + 1]; if constexpr(Problem::SkipExpertsWithZeroTokens) @@ -2390,17 +2390,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk = x - 1; // topk of this token int i_show = x != 0 ? 1 : 0; // has this token or not int cumsum = i_show; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; @@ -2441,17 +2441,17 @@ struct MoeSortingMultiPhaseKernel_P23 cumsum_store += i_show[j]; }); int cumsum = cumsum_store; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? prev : 0; // mask out cumsum += prev; @@ -2496,17 +2496,17 @@ struct MoeSortingMultiPhaseKernel_P23 int i_topk_1 = x1 - 1; // topk of this token int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not int cumsum = i_show_0 + i_show_1; - impl::moe_sorting_wave_cumsum(cumsum); + impl::moe_sorting_wave_cumsum(cumsum); __syncthreads(); - if(lane_id == warpSize - 1) + if(lane_id == WarpSize - 1) { s[4 + wave_id] = cumsum; } __syncthreads(); // reduce cross wave - static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) { IndexType prev = s[4 + i_w]; prev = wave_id > i_w ? 
prev : 0; // mask out cumsum += prev; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp index 629f0ee8f1..0c8baaf191 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack = GetSmemKPack_A(); // LDS @@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK = LanesPerK / WarpSize; if constexpr(wavesPerK > NumWarps) { // TODO: need multiple issues along K to load all data @@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds store vector(actually no explicit store) @@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple( make_pass_through_transform(number{}), 
make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); @@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds store vector(actually no explicit store) @@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy constexpr index_t Block_M = Problem::BlockShape::Block_M0; constexpr index_t Block_K = Problem::BlockShape::Block_K0; // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; - constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t NumWarps = Problem::BlockShape::NumWarps; constexpr index_t KPack = GetSmemKPack_A(); // LDS @@ -407,11 +407,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy static_assert(Block_K % KVector == 0); constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K - if constexpr(LanesPerK >= warpSize) + if constexpr(LanesPerK >= WarpSize) { // need multiple waves to load K - static_assert(LanesPerK % warpSize == 0); - constexpr index_t wavesPerK = LanesPerK / warpSize; + static_assert(LanesPerK % WarpSize == 0); + constexpr index_t wavesPerK 
= LanesPerK / WarpSize; if constexpr(wavesPerK >= NumWarps) { // TODO: need multiple issues along K to load all data @@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple(number{}, // m0 number{}, // m1 number{}, // k0 - number{}, // k1 + number{}, // k1 number{}), // k2 - make_tuple(number{}, // m0 - number{}, // m1 - number{}, // k0 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 number{}, // k1 number<1>{}), // k2 number{}, // lds load vector @@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy make_tuple( make_merge_transform(make_tuple(number{}, number{})), make_merge_transform(make_tuple( - number{}, number{}, number{}))), + number{}, number{}, number{}))), make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy else { // lanes within a wave load different M but same K - static_assert(warpSize % LanesPerK == 0); - constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + static_assert(WarpSize % LanesPerK == 0); + constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy number{}, // m2 number{}, // k0 number{}), // k1 - make_tuple(number{}, // m0 + make_tuple(number{}, // m0 number{}, // m1 - number{}, // m2 + number{}, // m2 number{}, // k0 number<1>{}), // k1 number{}, // lds load vector diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp index b038472fcf..ad513dbd11 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -26,7 +26,7 @@ struct TileImageToColumnShape static constexpr index_t 
kMWarpPerBlock = kMPerBlock / kMPerWarp; static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp; - static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock; + static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 15ac021631..26437c7126 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; return num_warps * 4 * thread_buf_size * sizeof(float); } @@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / WarpSize; const index_t smem_offset = warp_id; // skip if nonthing to do diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index d6ca98e7b4..6a1f926a9a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync // | w0 | w1 | w2 | w3 | -----> | w0123 | // // -> also store data from every wave into LDS - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); return num_warps * thread_buf_size * sizeof(DataType); } @@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync 
const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); - constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); const index_t smem_offset = warp_id; // skip if nonthing to do From cc98a41f465108af2ecf5168c7bd7844a64b6fc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 17 Jun 2025 22:25:56 +0200 Subject: [PATCH 018/103] Fix Add in dynamic buffer for fp32/i8 (#2351) * Fix Add in dynamic buffer for fp32/i8 * fixes * Fix --- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 6 +-- include/ck/utility/dynamic_buffer.hpp | 52 ++----------------- 2 files changed, 7 insertions(+), 51 deletions(-) mode change 100755 => 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp mode change 100755 => 100644 include/ck/utility/dynamic_buffer.hpp diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp old mode 100755 new mode 100644 index f1c0ec1c68..d45ed79ae3 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -1841,7 +1841,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, + AccDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -2591,7 +2591,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, + AccDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp old mode 100755 new mode 100644 index eb35c34498..2debd09c2d --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -139,8 +139,7 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value || - !is_native_type(), + typename scalar_type>::type>::value, bool>::type = false> __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) { @@ -160,37 +159,7 @@ struct DynamicBuffer { auto tmp = this->template Get(i, is_valid_element); using scalar_t = typename scalar_type>::type; - -#if defined(__gfx942__) || defined(__gfx950__) - - // Properly handle addition for all low-precision types - if constexpr(is_same_v || is_same_v) - { - if constexpr(is_scalar_type::value) - { - // Scalar type: Convert to float, add, convert back - auto result = - type_convert(type_convert(x) + type_convert(tmp)); - this->template Set(i, is_valid_element, result); - } - else - { - // Vector type - constexpr auto vector_size = scalar_type>::vector_size; - const vector_type a_vector{tmp}; - const vector_type b_vector{x}; - - // Process each element of the vector in higher precision - static_for<0, vector_size, 1>{}([&](auto idx) { - auto result = type_convert( - type_convert(a_vector.template AsType()[idx]) + - type_convert(b_vector.template AsType()[idx])); - this->template Set(i + idx, is_valid_element, result); - }); - } - } -#else - // handle bfloat addition + // handle bfloat addition if constexpr(is_same_v) { if constexpr(is_scalar_type::value) @@ -218,8 +187,6 @@ struct DynamicBuffer { this->template Set(i, is_valid_element, x + tmp); } - -#endif } } @@ -273,20 +240,9 @@ struct DynamicBuffer if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - using vector_t = typename vector_type_maker, t_per_x>::type::type; - vector_t tmp; - - if constexpr(is_same_v, vector_t>) - { - tmp = x; - } - else - { - __builtin_memcpy(&tmp, &x, sizeof(vector_t)); - } amd_buffer_store, t_per_x, coherence>( - tmp, p_data_, i, is_valid_element, element_space_size_ / 
PackedSize); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && is_same>::type, int8_t>::value && From cdfd7722bfda0181e9ccb75db4161fb95fdef353 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 17 Jun 2025 13:56:30 -0700 Subject: [PATCH 019/103] Revert "Shard several of the most costly targets. (#2266)" (#2361) This reverts commit 3a0cb2796605082cdbac4d1649397b9435e49556. --- .gitignore | 3 - cmake/ShardInstantiation.cmake | 116 ------------------ cmake/call_shard.in | 15 --- cmake/instantiate_shard.in | 9 -- include/ck/utility/filter_tuple.hpp | 66 ---------- .../gpu/grouped_convolution_forward_xdl.inc | 3 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 51 +------- ..._ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp} | 38 +++--- ...d_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp} | 40 +++--- ...wd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp} | 64 +++++----- ...c_gkyxc_nhwgk_int8_mem_inter_instance.cpp} | 100 +++++++-------- ...wgc_gkyxc_nhwgk_int8_mem_inter_instance.in | 80 ------------ ...c_gkyxc_nhwgk_int8_mem_intra_instance.cpp} | 100 +++++++-------- ...wgc_gkyxc_nhwgk_int8_mem_intra_instance.in | 80 ------------ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 109 +++------------- ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 111 +++++++++++++++++ ...ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in | 66 ---------- ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 111 +++++++++++++++++ ..._ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in | 65 ---------- ...gcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp | 54 ++++++++ ...ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in | 65 ---------- ...ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp | 54 ++++++++ ..._ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in | 63 ---------- ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 53 ++++++++ ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 53 ++++++++ ...ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp | 9 -- 
...ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp | 9 -- ...ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp | 9 -- ...gkczyx_ngkdhw_bf16_mem_inter_instance.cpp} | 53 ++++---- ...w_gkczyx_ngkdhw_bf16_mem_inter_instance.in | 64 ---------- ..._gkczyx_ngkdhw_bf16_mem_intra_instance.cpp | 55 +++++++++ ...w_gkczyx_ngkdhw_bf16_mem_intra_instance.in | 65 ---------- ..._gkczyx_ngkdhw_f16_mem_inter_instance.cpp} | 53 ++++---- ..._gkczyx_ngkdhw_f16_mem_intra_instance.cpp} | 69 +++++------ ..._gkczyx_ngkdhw_f32_mem_inter_instance.cpp} | 69 +++++------ ..._gkczyx_ngkdhw_f32_mem_intra_instance.cpp} | 69 +++++------ 41 files changed, 820 insertions(+), 1318 deletions(-) delete mode 100644 cmake/ShardInstantiation.cmake delete mode 100644 cmake/call_shard.in delete mode 100644 cmake/instantiate_shard.in delete mode 100644 include/ck/utility/filter_tuple.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp} (53%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp} (71%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp} (64%) rename library/src/tensor_operation_instance/gpu/{grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in => 
grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp} (54%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in rename library/src/tensor_operation_instance/gpu/{grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in => grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp} (54%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in 
create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in => mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp} (64%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in => mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp} (64%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp} (59%) diff --git a/.gitignore b/.gitignore index e4dd8f7513..599ef99e35 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,3 @@ build*/ # Python cache __pycache__/ - -.cache/ - diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake deleted file mode 100644 index 47a5d0c48c..0000000000 --- a/cmake/ShardInstantiation.cmake +++ /dev/null @@ -1,116 +0,0 @@ -# Function to generate templated instantiation functions and caller function. - -# In order to reduce build times, we split the instantiation of template functions into multiple files. 
-# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, -# which can be placed the TEMPLATE_FILE (typically a .in file). - -# This CMake function generates the instantiation functions and a caller function that calls all the instantiation -# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of -# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, -# and generates a caller function that calls all the instantiation functions. - -# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation -# of the template functions in the caller function, and that code is automatically generated by this function. - -# In addition to the user-supplied template, this CMake function uses two generic templates: -# -# 1. `instantiate_shard.in`: This is the template for the instantiation functions. -# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. - -# This function takes the following arguments: -# -# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`). -# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. -# - NUM_SHARDS: The number of shards to generate. -# - OUTPUT_DIR: The build directory where the generated source files will be placed. -# - SRC_LIST: The list of source files to which the generated source files will be added. - - -function(generate_sharded_instantiations) - cmake_parse_arguments( - GEN_SHARDED - # No boolean arguments - "" - # Single-value arguments - "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" - # No multi-value arguments. 
- "" - ${ARGN} - ) - if (NOT GEN_SHARDED_INSTANCES_NAME) - message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") - endif() - if (NOT GEN_SHARDED_TEMPLATE_FILE) - message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") - endif() - if (NOT GEN_SHARDED_NUM_SHARDS) - message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") - endif() - if(NOT GEN_SHARDED_OUTPUT_DIR) - message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") - endif() - if (NOT GEN_SHARDED_SRC_LIST) - message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") - endif() - - file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) - - - set(GENERATED_SOURCE_FILES "") - set(EXTERN_TEMPLATE_STATEMENTS "") - set(CALL_STATEMENTS "") - message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") - - set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") - - # Generate the inc file with the template function defintions. - # This include file will hold the template function definitions and a using alias for all the shard - # instantiation functions. - configure_file( - "${GEN_SHARDED_TEMPLATE_FILE}" - "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" - @ONLY - ) - - # Generate the sharded instantiation functions. - # This is where the build parallelization happens. - # Each of these source files will contain a single instantiation function for a shard, - # which will be called sequentially by the caller function. 
- set(INC_DIR "${GEN_SHARDED_INC_DIR}") - math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") - foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) - set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") - set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") - set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") - configure_file( - "${SHARD_FUNCTION_TEMPLATE}" - "${SHARD_FUNCTION_PATH}" - @ONLY - ) - list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") - set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") - list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") - list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") - endforeach() - - # Join the include statements, the extern template declarations, and the call statements each - # into a single string for variable substitution in the caller function. - string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") - string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") - string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") - - # Generate the caller function. - set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") - set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") - configure_file( - "${FUNCTION_TEMPLATE}" - "${CALLER_FUNCTION_PATH}" - @ONLY - ) - list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") - - # Add the generated source files to the list of source files. - # This allows the generated source files to be included in the build. 
- list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) - set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) -endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in deleted file mode 100644 index daba79b055..0000000000 --- a/cmake/call_shard.in +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "@INSTANCES@.inc" - -namespace ck::tensor_operation::device::instance { - -@EXTERN_TEMPLATE_STATEMENTS@; - -void add_@INSTANCES@( - @INSTANCES@& instances) { -@CALL_STATEMENTS@; -} - -} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in deleted file mode 100644 index dbc0af17a9..0000000000 --- a/cmake/instantiate_shard.in +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "@INSTANCES@.inc" - -namespace ck::tensor_operation::device::instance { -template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( - @INSTANCES@& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/include/ck/utility/filter_tuple.hpp b/include/ck/utility/filter_tuple.hpp deleted file mode 100644 index c2e378b879..0000000000 --- a/include/ck/utility/filter_tuple.hpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck/utility/functional.hpp" -#include "ck/utility/sequence.hpp" - -namespace ck::util { - -template -struct filter_tuple_by_modulo -{ - // Validate Stride and Offset. 
- static_assert(Stride > 0, "Offset must be positive."); - static_assert(Offset >= 0 && Offset < Stride, - "Offset must be positive and less than the stride."); - - // Generate filtered indices for this stride and offset. - static constexpr int new_size = (std::tuple_size_v + Stride - Offset - 1) / Stride; - - template - static constexpr auto to_index(std::index_sequence) - { - return std::index_sequence<(Offset + Is * Stride)...>{}; - } - - using filtered_indices = decltype(to_index(std::make_index_sequence{})); - - // Helper struct to construct the new tuple type from the filtered indices. - template - struct make_filtered_tuple_type_impl; - - template - struct make_filtered_tuple_type_impl> - { - using type = std::tuple...>; - }; - - using type = typename make_filtered_tuple_type_impl::type; -}; - -// Filter a tuple with a stride and offset. -// -// Tuple is a std::tuple or equivalent -// Stride is a positive integer -// Offset is a positive integer smaller than ofset -// -// Evaluates to a smaller tuple type from elements of T with stride M and offset I. -// -// Can be used to filter a tuple of types for sharded instantiations. 
-template -using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo::type; - -// Example compile-time test: -// using OriginalTuple = -// std::tuple; -// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t; -// static_assert(std::is_same_v>, -// "Test Case 1 Failed: Every 3rd from 2nd"); - -} // namespace ck::util diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index a3f2515099..b018737932 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -688,6 +688,7 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( std::vector>>; - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -template -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]] - device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances( instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); } -} // namespace ck::tensor_operation::device::instance \ No newline at end of file +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp similarity index 71% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 88c84adfe2..4ca1b2b85e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -3,11 +3,13 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances = +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( std::vector>>; - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -template -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances_shard( - device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + ConvFwdDefault>{}); add_device_operation_instances(instances, - 
ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + ConvFwd1x1S1P0>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index 13fb583725..e3a12fd5f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -3,11 +3,13 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances = +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( std::vector>>; - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] 
-template -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard( - device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp similarity index 54% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp index 7571dff883..f667481fa4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp @@ -1,62 
+1,66 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + 
ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in deleted file mode 100644 index d8b35bda68..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances = - std::vector>>; - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -template -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard( - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Interwave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Interwave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Interwave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp 
similarity index 54% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp index 91a2444241..2ff2c7f51f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp @@ -1,62 +1,66 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + 
add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in deleted file mode 100644 index 125e16139d..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances = - std::vector>>; - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -template -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard( - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 1d9d75a104..f8efa5a7c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -11,6 +11,8 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -30,13 +32,23 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp -xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + 
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -59,99 +71,6 @@ xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) -# Add generated files for sharded instantiations. 
-include(ShardInstantiation) - -set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances - TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in - NUM_SHARDS 8 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl -) -set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances - TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in - NUM_SHARDS 8 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl -) - -set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) - -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD 
- OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) - -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances - TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in - NUM_SHARDS 12 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/comp -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances - TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in - NUM_SHARDS 12 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/comp -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances - TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in - NUM_SHARDS 12 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/comp -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances - TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in - NUM_SHARDS 12 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/comp -) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND 
GROUPED_CONV3D_FWD diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..a94f687ef8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + 
ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in deleted file mode 100644 index e1a6e6c0c4..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances = - std::vector>>; - -template -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances_shard( - device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances& instances) -{ - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance - diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp new file mode 100644 index 0000000000..0c63345e7f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, 
+ GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in deleted file mode 100644 index 6d196ad71f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in +++ /dev/null @@ -1,65 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances = - std::vector>>; - -template -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances_shard( - device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances& instances) -{ - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); -} - -} // namespace 
ck::tensor_operation::device::instance - diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp new file mode 100644 index 0000000000..43241454a5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in deleted 
file mode 100644 index 4c67e4912c..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in +++ /dev/null @@ -1,65 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances& instances) -{ - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance - diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp new file mode 100644 index 
0000000000..d02d9f6778 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in deleted file mode 100644 index 0fbefa3bbc..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in +++ /dev/null @@ -1,63 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances& instances) -{ - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..060eebebc1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..85b088f416 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp deleted file mode 100644 index da2f3dc1fa..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_1of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 0>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp deleted file mode 100644 index 5d551833c0..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_2of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 1>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp deleted file mode 100644 index 715cbf6beb..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_3of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 2>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp deleted file mode 100644 index cf2a9f4023..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_4of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 3>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp deleted file mode 100644 index 085b2904d6..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_5of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 4>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp deleted file mode 100644 index 18b1e0c6d9..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_6of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 5>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp deleted file mode 100644 index b95f1d1229..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_7of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 6>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp deleted file mode 100644 index afe3e5d19f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance_8of8.cpp +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.inc" - -namespace ck::tensor_operation::device::instance { -template void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_sharded<8, 7>( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances); -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp index 
c87783eed9..fac3098341 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp @@ -1,14 +1,15 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances = +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances& instances) + PassThrough>>>& instances) { - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NGCDHW, GKCZYX, Empty_Tuple, NGKDHW, - ConvFwd1x1P0>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NGCDHW, GKCZYX, Empty_Tuple, 
NGKDHW, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); + ConvFwd1x1S1P0, + Interwave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in deleted file mode 100644 index 2586bc0f16..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>, - Shards, - ShardIndex>{}); - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>, - Shards, - ShardIndex>{}); - 
add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..f3eccc7dc8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in deleted file mode 100644 index 7405f86a5f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in +++ /dev/null @@ -1,65 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); - add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); -} - -} // namespace 
ck::tensor_operation::device::instance - diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp index ca6d571be1..abea0bea81 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp @@ -1,14 +1,15 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances = +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances) + PassThrough>>>& instances) { - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, NGCDHW, GKCZYX, Empty_Tuple, NGKDHW, - ConvFwd1x1P0>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - util::filter_tuple_by_modulo_t{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, NGCDHW, GKCZYX, Empty_Tuple, NGKDHW, - ConvFwd1x1S1P0>, - Shards, - ShardIndex>{}); + ConvFwd1x1S1P0, + Interwave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp index 24d6b66976..ba5d9fb1de 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp @@ -3,11 +3,13 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances = +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); 
add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp index 38ed240fab..5a2c4a0d5b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.inc +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp @@ -3,11 +3,13 @@ #include 
"ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp index 38ed240fab..701b8eb4a4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp @@ -3,11 +3,13 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - 
ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); add_device_operation_instances(instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From df54667102a3a1183fa55872eb6889717b42fde6 Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Tue, 17 Jun 2025 15:29:45 -0600 Subject: [PATCH 020/103] Add missing copyright headers (#2359) * Add missing copyright headers * empty commit --- example/ck_tile/18_flatmm/script/smoke_test_basic.sh | 4 ++++ example/ck_tile/35_batched_transpose/script/perf_test.sh | 5 ++++- example/ck_tile/35_batched_transpose/script/run_full_test.sh | 4 ++++ example/ck_tile/35_batched_transpose/script/smoke_test.sh | 5 ++++- .../test_batched_gemm_device_utils.hpp | 3 +++ test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc | 3 +++ test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc | 3 +++ test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc | 3 +++ .../test_gemm_universal_streamk_ut_cases_bf16.inc | 3 +++ .../test_gemm_universal_streamk_ut_cases_fp16.inc | 3 +++ .../test_gemm_universal_streamk_ut_cases_fp8.inc | 3 
+++ 11 files changed, 37 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/18_flatmm/script/smoke_test_basic.sh b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh index a3fc61cc31..6bcec3a812 100755 --- a/example/ck_tile/18_flatmm/script/smoke_test_basic.sh +++ b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh @@ -1,4 +1,8 @@ #!/bin/bash + +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)" KNAME=1 diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh b/example/ck_tile/35_batched_transpose/script/perf_test.sh index 7ecfefc580..dde646eb2a 100755 --- a/example/ck_tile/35_batched_transpose/script/perf_test.sh +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -1,5 +1,8 @@ #!/bin/sh +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + EXE=./build/bin/tile_example_batched_transpose for pr in "fp8" "fp16" "bf16"; do @@ -8,4 +11,4 @@ $EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' -done \ No newline at end of file +done diff --git a/example/ck_tile/35_batched_transpose/script/run_full_test.sh b/example/ck_tile/35_batched_transpose/script/run_full_test.sh index 4d0c988912..bd42959256 100755 --- a/example/ck_tile/35_batched_transpose/script/run_full_test.sh +++ b/example/ck_tile/35_batched_transpose/script/run_full_test.sh @@ -1,4 +1,8 @@ #!/bin/bash + +# Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the tile_example_batched_transpose executables in ../build/bin/ # diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index fdc01a2eb4..5ba2743364 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -1,5 +1,8 @@ #!/bin/sh +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + EXE=./build/bin/tile_example_batched_transpose for pr in "fp8" "fp16" "bf16"; do @@ -24,4 +27,4 @@ $EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' $EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' -done \ No newline at end of file +done diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp index 7d20ee4827..f8f621e9eb 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp @@ -1,5 +1,8 @@ #pragma once +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #include #include diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc index 233f86ef43..c344d10434 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc index adc84848f2..309b212249 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc index b831e15e9c..770107a2df 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc index 22977866b5..5cefd911a7 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, SmallM) diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc index 99c8e6d163..6deb867cd3 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, SmallM) diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc index b98ee92800..43140e0ef4 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc @@ -1,3 +1,6 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, SmallM) From 0eb8974502df073be0e131f25435a30ecbf9a656 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 18 Jun 2025 08:27:46 +0800 Subject: [PATCH 021/103] [CK_TILE] Support multi-config in tile_example_gemm_universal (#2240) * [CK_TILE] Support multi-config in tile_example_gemm_universal Add GemmConfig in run_gemm_example to support multiple tile config. - It is useful when you need to compare gemm perf with different tile/pipeline config - we also can use it to simplify the code for wmma support in the future. * [CK_TILE] Support multi-config in tile_example_gemm_universal Address review comments * rebase code and fix clang format. * fix clang format * support pipeline v5.
* fix merge conflict * address review comment * add missing file * address review comment v2 * fix build error --- example/ck_tile/03_gemm/gemm_basic.cpp | 41 +-- example/ck_tile/03_gemm/gemm_utils.hpp | 301 ++++++++++++------ example/ck_tile/03_gemm/run_gemm_example.inc | 40 ++- example/ck_tile/03_gemm/universal_gemm.cpp | 71 +++-- .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 +- 6 files changed, 306 insertions(+), 155 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 1906b0bda7..090a98486e 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -12,7 +12,8 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template + typename CDEElementWise> float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { @@ -140,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -156,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } else { - if(a_layout == "R" && b_layout == "R") + if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } + else if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return 
run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -211,15 +212,19 @@ int run_gemm_example(int argc, char* argv[]) return run_gemm_example_prec_type( a_layout, b_layout, argc, argv); } - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } } -#endif else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 6987a2492e..101e195903 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,105 +16,8 @@ #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 -#endif - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 -#define 
UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" -#endif - -struct GemmConfig +struct GemmConfigBase { -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 32; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 4; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#endif -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) - // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; - - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) - // Compute friendly for Intrawave scheduler - // Using the ping 
pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t NumWaveGroups = 1; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) - // Compute friendly for Intrawave scheduler - // Using the ping pong reader in the lds level - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 32; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 1; - static constexpr ck_tile::index_t K_Warp = 2; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool DoubleSmemBuffer = false; - - // Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups - // will be responsible for specific jobs. For instance, perform Global Memory read operations, - // perform block-gemm operation etc... 
- static constexpr ck_tile::index_t NumWaveGroups = 2; -#endif - static constexpr bool kPadM = false; static constexpr bool kPadN = false; static constexpr bool kPadK = false; @@ -128,6 +31,169 @@ struct GemmConfig static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 32 : 128; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
16 : 64; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
16 : 64; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template @@ -224,6 +290,45 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index cc9a825c73..140107bfb4 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template ; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); @@ -144,7 +146,22 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -template +float gemm(const ck_tile::GemmHostArgs<>& args, const
ck_tile::stream_config& s); + +template b_k_n_dev = b_k_n; if constexpr(GemmConfig::PermuteB) { - permute_tensor_b, AccDataType, diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3ec90e7f00..ecfaa92b9a 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -13,7 +13,8 @@ #include "gemm_utils.hpp" #include "run_gemm_example.inc" -template + typename CDEElementWise> float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { @@ -45,7 +46,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: GemmConfig::kPadK, ALayout, BLayout, - ELayout>; + ELayout, + GemmConfig::NumWaveGroups>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits& args, const ck_tile: using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -75,7 +78,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem& args, const ck_tile: has_hot_loop_v, tail_number_v>; - using GemmPipeline = GEMM_PIPELINE; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& args, const ck_tile: 
UniversalGemmProblem::TransposeC, memory_operation, GemmConfig::NumWaveGroups>>; - using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -205,7 +208,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: return ave_time; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -215,12 +221,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -233,22 +239,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Row{}, Row{}); } else if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_example_with_layouts( argc, argv, Col{}, Col{}, Row{}); } else @@ -258,6 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } +template