From c7fac341de038607ca775bded2b7e324aa1de387 Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Fri, 22 May 2026 20:39:01 +0200 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#4871 (commit 7d4c040) [CK] Decouple EpilogueArgs from GridwiseGemm implementation (#4871) This is a duplicate of #4537. I could not re-open it since the target branch got deleted and could not change the target branch since it was closed... :) ## Motivation Right now, all the Epilogue structs are declared inside the base gridwise struct. They should be independent of it and the specialization of the selected Epilogue Type should be declared within the kernel function. ## Technical Details All Epilogue structs depend on template parameters that are known to the base Gridwise Gemm struct. In this PR, we export them to be used independently by any struct that might need to extract them. This approach will serve the decoupling purposes for the Epilogues, but also enable future constructs to use and expand this approach. See 30e2a4c01b64bdea68857c7badd9d7cffbf1adb9. Right now an issue that arises is that when implementing a new Epilogue Type, the developer is not forced to decide where this struct should/can be used or not. To fix this I propose defining an `enum class EpilogueType` that will be used to fetch the Epilogue specialization through a helper struct. See a943ac8d130e12d6843715b322181186e54ba15c. Note that all the instantiation details will stay in this helper struct. Also note the static assertion in the else statement. ## Test Plan Test with existing CI, as nothing is added/removed. ## Test Result All relevant existing CI tests should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Signed-off-by: Chris Tsiaousis --- ...ontraction_multiple_d_wmma_cshuffle_v3.hpp | 15 +- ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 9 +- ...e_batched_gemm_reduce_wmma_cshuffle_v3.hpp | 18 +- ...e_batched_gemm_wmma_cshuffle_v3_common.hpp | 2 +- ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 9 +- ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 18 +- .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 17 +- .../device_gemm_wmma_cshuffle_v3_common.hpp | 2 +- ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 10 +- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 8 +- ..._bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 11 +- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 11 +- ...conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp | 15 +- ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 12 +- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 17 +- ..._multiple_d_wmma_cshuffle_tile_loop_v3.hpp | 15 +- .../device_grouped_gemm_wmma_fixed_nk.hpp | 17 +- ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 17 +- .../gpu/grid/epilogue_type.hpp | 125 +++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 310 ++---------------- ...e_gemm_wmma_cshuffle_v3_common_kernels.hpp | 223 +++++++++++++ 21 files changed, 510 insertions(+), 371 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_type.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp index c91cb06e6b..2f4926d452 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include 
"ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + ds_batch_offset[i]; }); - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseOp::IsBWaveTransferApplicable && GridwiseOp::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseOp::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const auto a_grid_desc_ak0_m_ak1 = @@ -70,7 +73,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto b_grid_desc_bk0_n_bk1 = GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseOp::template Run( p_as_grid_batch, p_bs_grid_batch, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index ae247f4e31..36035b319a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" #include 
"ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -61,8 +62,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -81,7 +84,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 5c42dc9745..26891b4367 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include 
"ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -50,9 +52,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; + constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -85,11 +89,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = EpilogueType(reduces_batch, - reduce_in_element_ops, - reduce_out_element_ops, - karg.M, - tensor_operation::element_wise::PassThrough{}); + auto epilogue_args = SelectedEpilogue(reduces_batch, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index fb1ca3127e..5503d9a697 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -8,7 +8,7 @@ #include "ck/host_utility/flush_cache.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp" #include #include diff 
--git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index b6d52b00dc..c84fa51a40 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -47,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = EpilogueType( + auto epilogue_args = SelectedEpilogue( p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op); GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index 13bba7626d..c1e9048d1b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" #include "ck/host_utility/device_prop.hpp" @@ -48,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueWelfordCShuffle>(); + using SelectedEpilogue = get_epilogue_t; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle( + auto epilogue_args = SelectedEpilogue( p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N); GridwiseGemm::template Run( @@ -298,14 +300,16 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence{}); } + using SelectedEpilogue = get_epilogue_t; + using LayernormMeanVarGridDesc_M_NBlock = - decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + decltype(SelectedEpilogue::template MakeMeanVarDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)); using LayernormCountGridDesc_M_NBlock = - decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + decltype(SelectedEpilogue::template MakeCountDescriptor_M_N< 
Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)); @@ -398,13 +402,13 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; }); layernorm_mean_var_grid_desc_m_nblock_ = - GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + SelectedEpilogue::template MakeMeanVarDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); layernorm_count_grid_desc_m_nblock_ = - GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + SelectedEpilogue::template MakeCountDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 8fa090e61a..e223cb5fbe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -46,18 +48,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template 
GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = EpilogueType(p_reduces_grid, - reduce_in_element_ops, - reduce_out_element_ops, - karg.M, - tensor_operation::element_wise::PassThrough{}); + auto epilogue_args = SelectedEpilogue(p_reduces_grid, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index c09befa717..e453247412 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 83665210ae..bc1a12d63b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -72,9 +73,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + auto epilogue_args = SelectedEpilogue{}; const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); index_t left = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index ed0378e23f..220a5de699 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -18,6 +18,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" 
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include @@ -68,13 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -69,12 +70,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + using SelectedEpilogue = get_epilogue_t; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; GridwiseGemm::template Run @@ -68,12 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - constexpr index_t LDS_size = 
GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + using SelectedEpilogue = get_epilogue_t; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; GridwiseGemm::template Run))) { #endif - using EpilogueType = - typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index 8a47abc845..5b3aa82940 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -20,6 +20,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include 
"ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -50,8 +51,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffset compute_ptr_offset_of_n) { #if defined(__gfx11__) || defined(__gfx12__) - using Epilogue = typename GridwiseGemm::EpilogueCShuffle; - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte()]; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -147,7 +151,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const index_t num_k_block_per_scale = GridwiseGemm::GetKBlockPerScale(); - auto epilogue_args = Epilogue{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::Base::template Run(p_as_grid_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 87be350a44..892ba4af54 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 
MinimumOccupancy) const CDEElementwiseOperation cde_element_op) { #if defined(__gfx11__) || defined(__gfx12__) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t KBatch = 1; @@ -139,13 +142,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run( p_as_grid_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp index c87bef3b93..d20fc1ec5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -17,6 +17,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -66,12 +67,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation cde_element_op) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename 
std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[LDS_size]; const auto gemm_desc_ptr = @@ -154,7 +157,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) gemm_desc_ptr[group_id].StrideE, 1); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; constexpr TailNumber TailNum = TailNumber::Full; if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 10490fa831..fe04c9bbde 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" namespace ck { namespace tensor_operation { @@ -94,12 +95,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? 
EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -179,13 +182,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run(static_cast(p_shared), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 99a18e07fc..86ed2270a6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -41,12 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t group_count) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? 
EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -93,13 +96,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run(static_cast(p_shared), diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp new file mode 100644 index 0000000000..aaa88dd9ef --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp @@ -0,0 +1,125 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" + +namespace ck { + +enum class EpilogueType +{ + CShuffle = 0, + DirectStore, + WelfordCShuffle, + ReduceCShuffle +}; + +template > +struct get_epilogue +{ + private: + static constexpr auto get_epilogue_implementation() + { + static_assert((type == EpilogueType::ReduceCShuffle) == + (!std::is_same_v>), + "Provide a ReduceTrait only if the desired epilogue type is ReduceCShuffle."); + using TypeExtractor = typename GridwiseGemm::Traits; + + if constexpr(type == EpilogueType::CShuffle) + { + return EpilogueCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename 
TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_>{}; + } + else if constexpr(type == EpilogueType::DirectStore) + { + return EpilogueDirectStore{}; + } + else if constexpr(type == EpilogueType::WelfordCShuffle) + { + return EpilogueWelfordCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_, + TypeExtractor::BlockSize_>{{}, {}, {}, {}, {}}; + } + else if constexpr(type == EpilogueType::ReduceCShuffle) + { + return EpilogueReduceCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + 
TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_, + TypeExtractor::GemmSpec_, + TypeExtractor::BlockSize_, + ReduceTrait>{{}, {}, {}, {}, {}}; + } + else + { + static_assert(false, "Not implemented for the specified type."); + } + } + + public: + using Type = decltype(get_epilogue_implementation()); +}; + +template > +using get_epilogue_t = typename get_epilogue::Type; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 41c76a1c91..93761228f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -19,10 +19,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" @@ -33,214 +29,6 @@ namespace ck { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS 
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. 
As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - if constexpr(IsBScaled) - { - const long_index_t b_scale_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_a_scale_grid, - karg.p_b_scale_grid + b_scale_batch_offset + - 
splitk_batch_offset.scale_b_k_split_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); - } - else - { - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); - } -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); - const index_t k_id = blockIdx.z * num_k_per_block; - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_shared, - splitk_batch_offset, - karg, - epilogue_args, - 0, /* A_k_id == 0 (we shift the pointer for splitk) */ - k_id); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - template ())>; - // Used to create obj in global function and pass it to Run method - using EpilogueCShuffle = - EpilogueCShuffle; + struct Traits + { + using DsDataType_ = DsDataType; + using EDataType_ = EDataType; + using 
AccDataType_ = AccDataType; + using CShuffleDataType_ = CShuffleDataType; + using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ = + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using CDEElementwiseOperation_ = CDEElementwiseOperation; + using ThisThreadBlock_ = ThisThreadBlock; + using BlockwiseGemmPipe_ = BlockwiseGemmPipe; - using EpilogueDirectStore = EpilogueDirectStore; - - using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< - DsDataType, - EDataType, - AccDataType, - CShuffleDataType, - MPerBlock, - NPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - CDEElementwiseOperation, - ThisThreadBlock, - BlockwiseGemmPipe, - BlockSize>; - - template - using EpilogueReduceCShuffle = EpilogueReduceCShuffle< - DsDataType, - EDataType, - AccDataType, - CShuffleDataType, - MPerBlock, - NPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - CDEElementwiseOperation, - ThisThreadBlock, - BlockwiseGemmPipe, - GemmSpec, - BlockSize, - ReduceTrait>; + static constexpr auto MPerBlock_ = MPerBlock; + static constexpr auto NPerBlock_ = NPerBlock; + static constexpr auto MPerWmma_ = MPerWmma; + static constexpr auto NPerWmma_ = NPerWmma; + static constexpr auto MRepeat_ = MRepeat; + static constexpr auto NRepeat_ = NRepeat; + static constexpr auto CShuffleMRepeatPerShuffle_ = CShuffleMRepeatPerShuffle; + static constexpr auto CShuffleNRepeatPerShuffle_ = CShuffleNRepeatPerShuffle; + static constexpr auto GemmSpec_ = GemmSpec; + static constexpr auto BlockSize_ = BlockSize; + 
}; template __host__ __device__ static constexpr auto @@ -1324,7 +1066,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } - template + template __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment @@ -1346,11 +1088,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base max_lds_align) : 0; - if constexpr(EpilogueType::IsLDSNeeded()) + if constexpr(Epilogue::IsLDSNeeded()) { // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - EpilogueType:: + Epilogue:: GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); constexpr auto c_block_size = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp new file mode 100644 index 0000000000..d5837575bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp @@ -0,0 +1,223 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = SelectedEpilogue{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! 
+ const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? 
EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = SelectedEpilogue{}; + + if constexpr(IsBScaled) + { + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + + splitk_batch_offset.scale_b_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS 
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + + auto epilogue_args = SelectedEpilogue{}; + + GridwiseGemm::template Run( + p_shared, + splitk_batch_offset, + karg, + epilogue_args, + 0, /* A_k_id == 0 (we shift the pointer for splitk) */ + k_id); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +} // namespace ck