From 513f92f5b989f76802e13a9ecef7bbbc6b64da3e Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 21 May 2025 02:40:20 -0500 Subject: [PATCH] update buffer load to lds feature, build passed --- .../moe_gemm2_xdl_mx_fp4_bns.cpp | 11 +- ...nsor_slice_transfer_gather_direct_load.hpp | 336 +++++++++ .../device/impl/device_moe_mx_gemm_bns.hpp | 40 +- .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 675 ++++++++++-------- 4 files changed, 720 insertions(+), 342 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 97a60ee6b2..33d0ec6713 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -40,7 +40,7 @@ using B0DataType = F4; using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -62,8 +62,8 @@ struct MulABScaleExpertWeight operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; // for real kernel use template <> - __host__ __device__ constexpr void operator()( - EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const { (void)d0; (void)d1; @@ -580,7 +580,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm2 +struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto block_slice_lengths = BlockSliceLengths{}; + static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + + static constexpr auto thread_single_load_size = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. + // It makes the whole wavefront load contiguous memory, what is required for direct loads. + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; + static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); + + static __device__ constexpr bool AreThreadClusterLengthsValid() + { + // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to + // LDS by the threads from a single wavefront. + // Examples (assuming 64 threads in a wavefront, 128 in a thread block): + // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp32 -> ScalarPerVector = 1 + // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of + // [0, 4, 0]. + // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration, + // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs). + // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp16 -> ScalarPerVector = 2 + // NOTE: ThreadClusterLengths must take into account that each thread writes two + // elements (single DWORD) along the contiguous dimension. + // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write + // 8 * 2 elements of K1PerBlock and there are only 8; + // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32 + // writes [1, 0, 0] instead of [0, 8, 0]. + // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the + // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive + // elements = 64 consecutive DWORDs. + int num_contiguous_dwords = 4; + bool is_contiguous = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(is_contiguous) + { + num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1]; + } + if(thread_slice_lengths[nDim - i - 1] > 1) + { + CK_PRINT>(); + is_contiguous = false; + } + }); + constexpr index_t wavefront_size = get_warp_size(); + const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0; + + bool thread_slice_lengths_correct = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(thread_slice_lengths[i] <= 0) + { + thread_slice_lengths_correct = false; + } + }); + + return wave_contiguous && thread_slice_lengths_correct; + } + + __device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const StaticallyIndexedArray& gather_offsets) + : gather_offsets_(gather_offsets) + { + static_assert(ck::is_same_v, + "Direct load transfer does not support datatypes conversion. Source and " + "destination data types must be the same."); + + static_assert( + DstVectorDim == nDim - 1, + "Direct load transfer requires the destination vector dimension to be the last one."); + + static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, + "When loading more than one element per thread at once, the contiguous " + "dimension must be the same between source and destination."); + + // constexpr auto dword_bytes = 4; + // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + // static_assert(bytes_per_thread_load == dword_bytes, + // "Direct load transfer requires each thread to load exactly a single " + // "DWORD of data."); + + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size(), + "Inconsistent number of dimensions across lengths and descriptors."); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "The number of threads cannot be less than the number of elements in " + "thread cluster lengths."); + + // static_assert( + // AreThreadClusterLengthsValid(), + // "Thread cluster lengths are incorrect. They must be set in a way that allows a single + // " "wavefront to write contiguous DWORDs into LDS memory. "); + + const auto thread_cluster_idx = + thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + + SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + auto adjusted_src_origin_idx = [&]() { + Index idx; + static_for<0, nDim, 1>{}([&](auto i) { + idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number{}]; + }); + return idx; + }(); + + src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx); + src_slice_origin_ = adjusted_src_origin_idx; + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_slice_origin_ = dst_slice_origin_idx; + } + + __device__ void ResetDstSliceWindow(const DstDesc& dst_desc) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + static_assert( + ck::is_same_v, remove_cvref_t>, + "SrcBuffer and SrcData data types must be consistent."); + static_assert( + ck::is_same_v, remove_cvref_t>, + "DstBuffer and DstData data types must be consistent."); + + constexpr auto dst_access_lengths = thread_slice_lengths; + + const auto dst_forward_steps = generate_steps(dst_desc, 1); + const auto dst_backward_steps = generate_steps(dst_desc, -1); + const auto src_forward_steps = generate_steps(src_desc, 1); + const auto src_backward_steps = generate_steps(src_desc, -1); + + // Loop over the destination block and copy data. + static_ford{}([&](auto ordered_dst_access_idx) { + // CK_PRINT(); + auto gather_offset = gather_offsets_(Number{}); + const auto src_offset = src_coord_.GetOffset() + gather_offset; + const auto dst_offset = dst_coord_.GetOffset(); + // printf("Tid: %03d, src_offset: %d, dst_offset: %d\n", get_thread_local_1d_id(), + // src_coord_.GetOffset(), dst_coord_.GetOffset()); + // Check if src data is not in the logic padding area. + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + src_buf.template DirectCopyToLds, ScalarPerVector>( + dst_buf, src_offset, dst_offset, is_src_valid); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + move_on_dim_(i) &= i.value != GatherDim; + }); + + return move_on_dim_; + } + (); + + // Decide whether to move forward or backward. + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); + } + } + }); + }); + + // Reset the destination slice since the entire buffer has been already filled. + ResetDstSliceWindow(dst_desc); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + src_slice_origin_ = src_slice_origin_ + step; + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_); + } + + template + __device__ auto generate_steps(const DescType& desc, int sign) + { + return generate_tuple( + [&](auto i) { + Index step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + SrcCoord src_coord_; + DstCoord dst_coord_; + Index src_slice_origin_; + Index dst_slice_origin_; + StaticallyIndexedArray gather_offsets_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp index 773ccb9fba..c49996c4d5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp @@ -299,20 +299,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle; + const auto kernel = kernel_moe_mxgemm_2lds; RunKernel(kernel); } else { - const auto kernel = kernel_moe_mxgemm; + const auto kernel = kernel_moe_mxgemm_2lds; RunKernel(kernel); } } @@ -348,20 +348,20 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle; + const auto kernel = kernel_moe_mxgemm_2lds; RunKernel(kernel); } else { - const auto kernel = kernel_moe_mxgemm; + const auto kernel = kernel_moe_mxgemm_2lds; RunKernel(kernel); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index ee0bfb4a3f..4b2579b6b3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -10,13 +10,13 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp" #define DEBUG_LOG 0 @@ -72,7 +72,6 @@ __global__ void #endif // end of if (defined(__gfx9__)) } -#if 0 template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run_2Lds( karg.p_sorted_token_ids, karg.p_sorted_expert_ids, karg.p_max_token_id, - karg.p_a_grid, - karg.p_a_scale_grid, - karg.p_b_grid, - karg.p_b_scale_grid, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - p_shared, - p_shared1, + p_shared_0, + p_shared_1, karg, karg.a_element_op, karg.b_element_op, @@ -111,7 +110,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = karg; #endif // end of if (defined(__gfx9__)) } -#endif template {}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - return transform_tensor_descriptor( + constexpr auto permuted_desc = transform_tensor_descriptor( TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), make_unmerge_transform(make_tuple(Number{}, Number{}, @@ -398,12 +404,29 @@ struct GridwiseMoeGemmMXBNS // not pad M or K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; } } @@ -487,12 +510,29 @@ struct GridwiseMoeGemmMXBNS // not pad N or K const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), make_pass_through_transform(N)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return b_grid_desc_bk0_n_bk1; + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; } } @@ -810,9 +850,10 @@ struct GridwiseMoeGemmMXBNS // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in LDS return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. @@ -927,9 +968,10 @@ struct GridwiseMoeGemmMXBNS // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in lds return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(BK1Number, Number{}, I1)); } else if constexpr(is_same::value) { @@ -1492,12 +1534,9 @@ struct GridwiseMoeGemmMXBNS // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, @@ -1506,55 +1545,34 @@ struct GridwiseMoeGemmMXBNS decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, IndexType, - 1, - BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); // B matrix blockwise copy auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1790,7 +1808,6 @@ struct GridwiseMoeGemmMXBNS m0 * M2 * M1 * M3 * M4 * M5 + m1 * M2 * M3 * M4 * M5 + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; - if constexpr(MulRoutedWeight) { topk_weights = @@ -2131,7 +2148,6 @@ struct GridwiseMoeGemmMXBNS } } -#if 0 template @@ -2144,8 +2160,8 @@ struct GridwiseMoeGemmMXBNS const BScaleDataType* p_b_scale_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, - void* p_shared, - void* p_shared1, + void* p_shared_0, + void* p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -2183,8 +2199,8 @@ struct GridwiseMoeGemmMXBNS const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); - // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; @@ -2252,112 +2268,100 @@ struct GridwiseMoeGemmMXBNS const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( - problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + problem.N * (IsInputGemm ? 2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + // Gride buffer creation const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - -#if 1 - printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - -#endif - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - LDSTypeA, + ADataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, IndexType, - 1, - BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); - - // Thread-wise copy - // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf_ping = make_static_buffer( - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - auto b_block_buf_pong = make_static_buffer( - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); + // B matrix blockwise copy auto b_blockwise_copy = - ThreadwiseTensorSliceTransfer_v2{}, - I1, - Number{}, - Number{}, - Number{}>, - Sequence<1, 2, 0, 3, 4>, - 4, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment - // Cast after lds + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + auto a_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -2429,22 +2433,25 @@ struct GridwiseMoeGemmMXBNS const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( p_b_grid_up + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, - decltype(b_grid_desc_bpreshuffled), + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; const auto b_scale_grid_buf_up = make_dynamic_buffer( p_b_scale_grid_up + expert_id * expert_scale_stride, @@ -2472,7 +2479,7 @@ struct GridwiseMoeGemmMXBNS a_grid_buf, a_block_bufs, a_block_slice_copy_step, - b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_blockwise_copy_up, @@ -2495,23 +2502,23 @@ struct GridwiseMoeGemmMXBNS else { blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, + a_grid_desc_ak0_m_ak1, // A a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_bufs, a_block_slice_copy_step, - b_grid_desc_bpreshuffled, + b_grid_desc_bk0_n_bk1, // B b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, b_block_bufs, b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale a_scale_thread_copy, a_scale_grid_buf, - b_scale_grid_desc_bn_ak, + b_scale_grid_desc_bn_ak, // B scale b_scale_thread_copy, b_scale_grid_buf, num_k_block_main_loop); @@ -2522,89 +2529,102 @@ struct GridwiseMoeGemmMXBNS static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0 / MXdlPack, - n0 / NXdlPack, - m0 % MXdlPack, - n0 % NXdlPack, - m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); } - } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); }); }); }); @@ -2614,28 +2634,33 @@ struct GridwiseMoeGemmMXBNS GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), + static_cast(p_shared_0), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple(make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -2647,8 +2672,8 @@ struct GridwiseMoeGemmMXBNS const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -2657,8 +2682,8 @@ struct GridwiseMoeGemmMXBNS const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto n_thread_data_on_block_idx = @@ -2666,36 +2691,39 @@ struct GridwiseMoeGemmMXBNS make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -2716,16 +2744,18 @@ struct GridwiseMoeGemmMXBNS // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2746,51 +2776,63 @@ struct GridwiseMoeGemmMXBNS const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence