From d7e80d2ae010dbb17ba7c3e30225a40f84552e98 Mon Sep 17 00:00:00 2001 From: ozturkosu Date: Fri, 13 Jun 2025 04:54:51 -0400 Subject: [PATCH] gridwise update --- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 176 ++++++------------ 1 file changed, 56 insertions(+), 120 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index f1c0ec1c68..bd00a64d7c 100755 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -140,24 +139,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = - (((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && - lcm_AK1_BK1 < 32)) - ? true - : false; - static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) @@ -239,23 +223,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); - // Pad both M and K to be multiples of the block sizes - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; -#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MKPadding || @@ -322,7 +289,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return a_grid_desc_ak0_m_ak1; } -#endif } __device__ static auto MakeBGridDescriptor_BK0_N_BK1( @@ -339,23 +305,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); - // Pad both N and K to be multiples of the block sizes - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; -#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -422,7 +371,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return b_grid_desc_bk0_n_bk1; } -#endif } template @@ -457,13 +405,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); - // Pad both M and N to be multiples of the block sizes - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); -#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -501,7 +442,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 // not pad M or N return c_grid_desc_mraw_nraw; } -#endif } struct Problem @@ -513,8 +453,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 index_t StrideB_, index_t StrideC_, index_t Streamk_sel_, - index_t Grid_size_, - StreamKReductionStrategy reduction_strategy_) + index_t Grid_size_) : M{M_}, N{N_}, K{K_}, @@ -523,7 +462,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 StrideC{StrideC_}, Streamk_sel{Streamk_sel_}, Grid_size{Grid_size_}, - reduction_strategy{reduction_strategy_}, // Initialize the member variable MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, KRead{CalculateKRead(K_, 1)}, @@ -552,13 +490,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << ", " - << "Stream-K Selection:" << Streamk_sel << ", " - << "Grid size:" << Grid_size << ", " - << "Reduction Strategy:" - << (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic" - : "Reduction") - << "}" << std::endl; + << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel + << ", Grid size:" << Grid_size << "}" << std::endl; } index_t M; @@ -569,7 +502,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 index_t StrideC; index_t Streamk_sel; mutable index_t Grid_size; - StreamKReductionStrategy reduction_strategy; index_t MPadded; index_t NPadded; index_t KRead; @@ -593,26 +525,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 index_t StrideB_, index_t StrideC_, index_t Streamk_sel_, - index_t Grid_size_, - StreamKReductionStrategy reduction_strategy_) - : Problem{M_, - N_, - K_, - StrideA_, - StrideB_, - StrideC_, - Streamk_sel_, - Grid_size_, - reduction_strategy_}, + index_t Grid_size_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_c_grid{p_c_grid_}, - block_2_ctile_map_streamk(M_, - N_, - AK0Number * CalculateKPadded(K_, 1), - Grid_size_, - Streamk_sel_, - reduction_strategy_) + block_2_ctile_map_streamk( + M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_), + launch_grid_dims_{0, 0, 0} // Initialize grid dims to zero { } @@ -627,6 +547,18 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 8, 4> block_2_ctile_map_streamk; + + mutable dim3 launch_grid_dims_; + + void SetLaunchGridDims(dim3 dims) const + { + launch_grid_dims_ = dims; + } + + dim3 GetLaunchGridDims() const + { + return launch_grid_dims_; + } }; struct SplitKBatchOffset @@ -1027,8 +959,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - !(is_same::value)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) { if(!(karg.M % MPerBlock == 0)) { @@ -1045,8 +976,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - (is_same::value)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) { if(!(karg.N % NPerBlock == 0)) { @@ -1112,7 +1042,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; } } @@ -1128,10 +1057,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; return false; } } @@ -1146,7 +1071,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; } } @@ -1164,7 +1088,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; } } @@ -1181,11 +1104,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; } } + if constexpr(is_same, bhalf_t>::value) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + } + // check gridwise gemm pipeline const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); @@ -1288,13 +1220,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size, - problem.Streamk_sel, - problem.reduction_strategy); + problem.Streamk_sel); uint32_t iter_start, iter_end; bool is_sk_block, is_dp_block, is_reduction_block; index_t num_k_block_main_loop; @@ -1309,7 +1239,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 uint32_t* p_semaphore = reinterpret_cast( reinterpret_cast(p_workspace) + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); - for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims(); block_idx += gridDim.x) @@ -1325,7 +1254,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); num_k_block_main_loop = iter_end - iter_start; - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { is_reduction_block = static_cast(block_idx) >= block_2_ctile_map_streamk.reduction_start_block_idx; @@ -1913,7 +1843,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } else if(is_sk_block) { - if(problem.reduction_strategy == StreamKReductionStrategy::Atomic) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Atomic) { // each block copy its data from LDS to global c_shuffle_block_copy_lds_to_global @@ -1925,8 +1856,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf); } - else if(problem.reduction_strategy == - StreamKReductionStrategy::Reduction) + else if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { // constexpr offset c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( @@ -1958,7 +1889,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }); - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { if(is_sk_block) { @@ -1973,7 +1905,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 iter_end -= current_iter_length; if(iter_end <= iter_start) break; - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { block_acc_offset -= MPerBlock * NPerBlock; } @@ -2028,8 +1961,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 problem.N, AK0Number * problem.KPadded, problem.Grid_size, - problem.Streamk_sel, - problem.reduction_strategy); + problem.Streamk_sel); for(auto block_idx = get_block_1d_id(); block_idx < block_2_ctile_map_streamk.get_grid_dims(); block_idx += gridDim.x) @@ -2048,7 +1980,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 reinterpret_cast(p_workspace) + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { is_reduction_block = static_cast(block_idx) >= block_2_ctile_map_streamk.reduction_start_block_idx; @@ -2664,7 +2597,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } else if(is_sk_block) { - if(problem.reduction_strategy == StreamKReductionStrategy::Atomic) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Atomic) { // each block copy its data from LDS to global c_shuffle_block_copy_lds_to_global @@ -2676,8 +2610,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf); } - else if(problem.reduction_strategy == - StreamKReductionStrategy::Reduction) + else if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { // constexpr offset c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( @@ -2712,14 +2646,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 iter_end -= current_iter_length; if(iter_end <= iter_start) break; - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { block_acc_offset -= MPerBlock * NPerBlock; } // make sure next loop LDS is ready for use block_sync_lds(); } - if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { if(is_sk_block) {