From fe8bb251da64e173c4bbb366098a352cf1d258b0 Mon Sep 17 00:00:00 2001 From: OscarXu Date: Mon, 12 May 2025 21:08:42 +0800 Subject: [PATCH] gemm1 up-only pass. GU WIP --- .../moe_gemm1_xdl_fp8.cpp | 2 +- .../moe_gemm1_xdl_fp8_blockscale.cpp | 14 +- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 25 +++- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 14 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 130 +++++++++++++++++- .../cpu/reference_moe_gemm1_blockscale.hpp | 4 +- 6 files changed, 171 insertions(+), 18 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index a05234ad3c..fa8f95db25 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -269,7 +269,7 @@ int main(int argc, char* argv[]) // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = i / (valid_tile_num / experts); + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index d5942a506d..155c052b4e 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -268,7 +268,7 @@ int main(int argc, char* argv[]) constexpr auto StrideDs = std::array{0}; ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; - ck::index_t Scale_Stride_B = (N * 2 + Scale_Block_N - 1) / Scale_Block_N; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; ck::index_t KBatch = 1; @@ -279,7 +279,7 @@ int main(int argc, char* argv[]) // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = i / (valid_tile_num / experts); + expert_ids.mData[i] = (i / valid_tile_num) / experts; } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; @@ -303,7 +303,7 @@ int main(int argc, char* argv[]) {Scale_Stride_AM, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor b1_e_n_k(HostTensorDescriptor( - {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N * 2 + Scale_Block_N - 1) / Scale_Block_N}, + {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2}, {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); @@ -381,8 +381,8 @@ int main(int argc, char* argv[]) DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); // a0_t_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); // d2_e_n.savetxt("d2_e_n.txt", "int"); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -503,8 +503,8 @@ int main(int argc, char* argv[]) expert_ids, max_token_id, MPerBlock, - a0_t_k, - b0_e_n_k, + a_t_k, + b_e_n_k, d2_e_n, c_t_k_n, PassThrough{}, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index 1cfe72981b..b815fcf83c 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -131,6 +131,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter } } + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, @@ -170,7 +182,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs, scatter_offsets); + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets, StaticallyIndexedArray& scatter_weights) { RunRead(src_descs, src_bufs, scatter_weights); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 5240d29f1e..f5e4b759f0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1206,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN), + make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN), math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); @@ -1267,7 +1267,7 @@ struct GridwiseMoeGemmBlockScale }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN) * + __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK)); // N0, K0, Blocksize*KPack @@ -1958,7 +1958,7 @@ struct GridwiseMoeGemmBlockScale math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN), + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = @@ -2019,7 +2019,7 @@ struct GridwiseMoeGemmBlockScale }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N * (IsInputGemm ? 2 : 1), ScaleBlockN) * + __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) * math::integer_divide_ceil(problem.K, ScaleBlockK)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -2158,8 +2158,8 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scale_gather_offsets(m0) = - token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK); + scale_gather_offsets(m0) = static_cast(token_offset) * + math::integer_divide_ceil(problem.K, ScaleBlockK); }); // printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id, @@ -2222,7 +2222,7 @@ struct GridwiseMoeGemmBlockScale KPack / KGroup * (get_thread_local_1d_id() % warpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, + p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2 // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + template = false> @@ -413,7 +526,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -485,6 +598,21 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs, scatter_offsets); + } + + template = false> __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, const DstDescs& dst_descs, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp index 367f2ea586..6abccc7bc6 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp @@ -163,7 +163,7 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c_up, v_acc_up); - + #if 0 if constexpr(ActivationType == 1) { if constexpr(is_same_v) @@ -184,6 +184,8 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator tensor_operation::element_wise::Gelu{}(v_c, v_c); arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; } + #endif + arg.c_t_k_n_(t, topk_id, n) = v_c; } };