diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp index c91cb06e6b..2f4926d452 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + ds_batch_offset[i]; }); - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseOp::IsBWaveTransferApplicable && GridwiseOp::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseOp::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const auto a_grid_desc_ak0_m_ak1 = @@ -70,7 +73,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto b_grid_desc_bk0_n_bk1 = GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseOp::template Run( p_as_grid_batch, p_bs_grid_batch, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index ae247f4e31..36035b319a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -61,8 +62,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -81,7 +84,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 5c42dc9745..26891b4367 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -50,9 +52,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; + constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -85,11 +89,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = EpilogueType(reduces_batch, - reduce_in_element_ops, - reduce_out_element_ops, - karg.M, - tensor_operation::element_wise::PassThrough{}); + auto epilogue_args = SelectedEpilogue(reduces_batch, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index fb1ca3127e..5503d9a697 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -8,7 +8,7 @@ #include "ck/host_utility/flush_cache.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp" #include #include diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index b6d52b00dc..c84fa51a40 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -47,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = EpilogueType( + auto epilogue_args = SelectedEpilogue( p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op); GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index 13bba7626d..c1e9048d1b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" #include "ck/host_utility/device_prop.hpp" @@ -48,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueWelfordCShuffle>(); + using SelectedEpilogue = get_epilogue_t; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle( + auto epilogue_args = SelectedEpilogue( p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N); GridwiseGemm::template Run( @@ -298,14 +300,16 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence{}); } + using SelectedEpilogue = get_epilogue_t; + using LayernormMeanVarGridDesc_M_NBlock = - decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + decltype(SelectedEpilogue::template MakeMeanVarDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)); using LayernormCountGridDesc_M_NBlock = - decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + decltype(SelectedEpilogue::template MakeCountDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)); @@ -398,13 +402,13 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; }); layernorm_mean_var_grid_desc_m_nblock_ = - GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< + SelectedEpilogue::template MakeMeanVarDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); layernorm_count_grid_desc_m_nblock_ = - GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< + SelectedEpilogue::template MakeCountDescriptor_M_N< Sequence, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 8fa090e61a..e223cb5fbe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -12,6 +12,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -46,18 +48,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + using SelectedEpilogue = + get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = EpilogueType(p_reduces_grid, - reduce_in_element_ops, - reduce_out_element_ops, - karg.M, - tensor_operation::element_wise::PassThrough{}); + auto epilogue_args = SelectedEpilogue(p_reduces_grid, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index c09befa717..e453247412 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 83665210ae..bc1a12d63b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -72,9 +73,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + auto epilogue_args = SelectedEpilogue{}; const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); index_t left = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index ed0378e23f..220a5de699 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -18,6 +18,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include @@ -68,13 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -69,12 +70,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + using SelectedEpilogue = get_epilogue_t; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; GridwiseGemm::template Run @@ -68,12 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + using SelectedEpilogue = get_epilogue_t; const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = SelectedEpilogue{}; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; GridwiseGemm::template Run))) { #endif - using EpilogueType = - typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index 8a47abc845..5b3aa82940 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -20,6 +20,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -50,8 +51,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const ComputePtrOffset compute_ptr_offset_of_n) { #if defined(__gfx11__) || defined(__gfx12__) - using Epilogue = typename GridwiseGemm::EpilogueCShuffle; - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte()]; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -147,7 +151,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const index_t num_k_block_per_scale = GridwiseGemm::GetKBlockPerScale(); - auto epilogue_args = Epilogue{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::Base::template Run(p_as_grid_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 87be350a44..892ba4af54 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CDEElementwiseOperation cde_element_op) { #if defined(__gfx11__) || defined(__gfx12__) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t KBatch = 1; @@ -139,13 +142,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run( p_as_grid_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp index c87bef3b93..d20fc1ec5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -17,6 +17,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -66,12 +67,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation cde_element_op) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[LDS_size]; const auto gemm_desc_ptr = @@ -154,7 +157,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) gemm_desc_ptr[group_id].StrideE, 1); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; constexpr TailNumber TailNum = TailNumber::Full; if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 10490fa831..fe04c9bbde 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" namespace ck { namespace tensor_operation { @@ -94,12 +95,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -179,13 +182,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run(static_cast(p_shared), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 99a18e07fc..86ed2270a6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -41,12 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t group_count) { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -93,13 +96,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = EpilogueType{}; + auto epilogue_args = SelectedEpilogue{}; GridwiseGemm::template Run(static_cast(p_shared), diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp new file mode 100644 index 0000000000..aaa88dd9ef --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_type.hpp @@ -0,0 +1,125 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" + +namespace ck { + +enum class EpilogueType +{ + CShuffle = 0, + DirectStore, + WelfordCShuffle, + ReduceCShuffle +}; + +template > +struct get_epilogue +{ + private: + static constexpr auto get_epilogue_implementation() + { + static_assert((type == EpilogueType::ReduceCShuffle) == + (!std::is_same_v>), + "Provide a ReduceTrait only if the desired epilogue type is ReduceCShuffle."); + using TypeExtractor = typename GridwiseGemm::Traits; + + if constexpr(type == EpilogueType::CShuffle) + { + return EpilogueCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_>{}; + } + else if constexpr(type == EpilogueType::DirectStore) + { + return EpilogueDirectStore{}; + } + else if constexpr(type == EpilogueType::WelfordCShuffle) + { + return EpilogueWelfordCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_, + TypeExtractor::BlockSize_>{{}, {}, {}, {}, {}}; + } + else if constexpr(type == EpilogueType::ReduceCShuffle) + { + return EpilogueReduceCShuffle< + typename TypeExtractor::DsDataType_, + typename TypeExtractor::EDataType_, + typename TypeExtractor::AccDataType_, + typename TypeExtractor::CShuffleDataType_, + TypeExtractor::MPerBlock_, + TypeExtractor::NPerBlock_, + TypeExtractor::MPerWmma_, + TypeExtractor::NPerWmma_, + TypeExtractor::MRepeat_, + TypeExtractor::NRepeat_, + TypeExtractor::CShuffleMRepeatPerShuffle_, + TypeExtractor::CShuffleNRepeatPerShuffle_, + typename TypeExtractor:: + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_, + typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_, + typename TypeExtractor::CDEElementwiseOperation_, + typename TypeExtractor::ThisThreadBlock_, + typename TypeExtractor::BlockwiseGemmPipe_, + TypeExtractor::GemmSpec_, + TypeExtractor::BlockSize_, + ReduceTrait>{{}, {}, {}, {}, {}}; + } + else + { + static_assert(false, "Not implemented for the specified type."); + } + } + + public: + using Type = decltype(get_epilogue_implementation()); +}; + +template > +using get_epilogue_t = typename get_epilogue::Type; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 41c76a1c91..93761228f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -19,10 +19,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" @@ -33,214 +29,6 @@ namespace ck { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - if constexpr(IsBScaled) - { - const long_index_t b_scale_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_a_scale_grid, - karg.p_b_scale_grid + b_scale_batch_offset + - splitk_batch_offset.scale_b_k_split_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); - } - else - { - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); - } -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); - const index_t k_id = blockIdx.z * num_k_per_block; - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_shared, - splitk_batch_offset, - karg, - epilogue_args, - 0, /* A_k_id == 0 (we shift the pointer for splitk) */ - k_id); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - template ())>; - // Used to create obj in global function and pass it to Run method - using EpilogueCShuffle = - EpilogueCShuffle; + struct Traits + { + using DsDataType_ = DsDataType; + using EDataType_ = EDataType; + using AccDataType_ = AccDataType; + using CShuffleDataType_ = CShuffleDataType; + using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ = + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using CDEElementwiseOperation_ = CDEElementwiseOperation; + using ThisThreadBlock_ = ThisThreadBlock; + using BlockwiseGemmPipe_ = BlockwiseGemmPipe; - using EpilogueDirectStore = EpilogueDirectStore; - - using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< - DsDataType, - EDataType, - AccDataType, - CShuffleDataType, - MPerBlock, - NPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - CDEElementwiseOperation, - ThisThreadBlock, - BlockwiseGemmPipe, - BlockSize>; - - template - using EpilogueReduceCShuffle = EpilogueReduceCShuffle< - DsDataType, - EDataType, - AccDataType, - CShuffleDataType, - MPerBlock, - NPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - CDEElementwiseOperation, - ThisThreadBlock, - BlockwiseGemmPipe, - GemmSpec, - BlockSize, - ReduceTrait>; + static constexpr auto MPerBlock_ = MPerBlock; + static constexpr auto NPerBlock_ = NPerBlock; + static constexpr auto MPerWmma_ = MPerWmma; + static constexpr auto NPerWmma_ = NPerWmma; + static constexpr auto MRepeat_ = MRepeat; + static constexpr auto NRepeat_ = NRepeat; + static constexpr auto CShuffleMRepeatPerShuffle_ = CShuffleMRepeatPerShuffle; + static constexpr auto CShuffleNRepeatPerShuffle_ = CShuffleNRepeatPerShuffle; + static constexpr auto GemmSpec_ = GemmSpec; + static constexpr auto BlockSize_ = BlockSize; + }; template __host__ __device__ static constexpr auto @@ -1324,7 +1066,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } - template + template __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment @@ -1346,11 +1088,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base max_lds_align) : 0; - if constexpr(EpilogueType::IsLDSNeeded()) + if constexpr(Epilogue::IsLDSNeeded()) { // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - EpilogueType:: + Epilogue:: GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); constexpr auto c_block_size = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp new file mode 100644 index 0000000000..d5837575bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp @@ -0,0 +1,223 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = SelectedEpilogue{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + constexpr auto epilogue_type = + GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore + ? EpilogueType::DirectStore + : EpilogueType::CShuffle; + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = SelectedEpilogue{}; + + if constexpr(IsBScaled) + { + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + + splitk_batch_offset.scale_b_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + + using SelectedEpilogue = get_epilogue_t; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + + auto epilogue_args = SelectedEpilogue{}; + + GridwiseGemm::template Run( + p_shared, + splitk_batch_offset, + karg, + epilogue_args, + 0, /* A_k_id == 0 (we shift the pointer for splitk) */ + k_id); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +} // namespace ck