From 8a5bb9f34bbaaf89b087f00b3a69b603c00a3a95 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sat, 8 Feb 2025 09:52:10 +0000 Subject: [PATCH 1/5] add files , build and run ok --- .../65_gemm_multiply_multiply/CMakeLists.txt | 3 +- .../{moe_gemm_fp16.cpp => moe_gemm1.cpp} | 0 .../65_gemm_multiply_multiply/moe_gemm2.cpp | 399 ++++++++++++++++++ ...hread_group_tensor_slice_transfer_v7r3.hpp | 3 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 50 +-- .../cpu/reference_moe_gemm2.hpp | 173 ++++++++ 6 files changed, 598 insertions(+), 30 deletions(-) rename example/65_gemm_multiply_multiply/{moe_gemm_fp16.cpp => moe_gemm1.cpp} (100%) create mode 100644 example/65_gemm_multiply_multiply/moe_gemm2.cpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index e4e10d5b76..8b03a808ec 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -5,4 +5,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_m # target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm_fp16 moe_gemm_fp16.cpp) +add_example_executable(example_moe_gemm1 moe_gemm1.cpp) +add_example_executable(example_moe_gemm2 moe_gemm2.cpp) diff --git a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp similarity index 100% rename from example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp rename to example/65_gemm_multiply_multiply/moe_gemm1.cpp diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp new file mode 100644 index 0000000000..008e2681bb --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +// using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(EDataType& e, + const float& c, + const float& d0, + const float& d1) const + { + // const float x0_f = c * d0 * d1; + const float x0_f = c; + // printf("epi %f\n", c); + e = ck::type_convert(x0_f); + } + + // template <> + // __host__ __device__ constexpr void operator()(BF16& e, + // const float& c, + // const float& d0, + // const float& d1) const + // { + // const float x0_f = c; + // // const float x0_f = c * d0 * d1; + + // e = ck::type_convert(x0_f); + // } +}; + + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); +static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + 256, MPerBlock, 128, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + 32, 32, + // mn_xdlperwave + MXDLPerWave, 1, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + 1, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>; + // kernel 2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + +// tokens = 1 +// topk = 1 +// experts = 8 +// per expert: + // GEMM shape + ck::index_t N = 128; + ck::index_t K = 1024; + ck::index_t experts = 1; + ck::index_t sorted_tile_num = 1; + ck::index_t sorted_tile_size = MPerBlock; + ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; + ck::index_t tokens = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 5: N, K\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + // const ck::index_t experts = 8; + Tensor expert_ids(HostTensorDescriptor({experts}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); + for (int i = 0; i < sorted_tile_num; i++) { + expert_ids.mData[i] = i; + } + int token_per_tile = tokens / sorted_tile_num; + int tokenid = 0; + // sorted_token_ids.mData[0] = 0; + for (int i = 0; i < SORTED_SIZE; i++) { + int tile_off = i % sorted_tile_size; + if(tile_off < token_per_tile) + sorted_token_ids.mData[i] = tokenid++; + else + sorted_token_ids.mData[i] = tokens; + } + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + // Tensor b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{})); + // Tensor b0_preshuffled( + // f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{})); + Tensor d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_t_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + a0_m_k.savetxt("a.txt"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_t_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + SORTED_SIZE, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if (time_kernel) { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; + std::size_t num_btype = + sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument( + sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < SORTED_SIZE; ++m) + { + + const int t = sorted_token_ids(m); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + e_t_n_device_result.savetxt("out.txt"); + e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp index 46d0c6ac2e..1feea921a6 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -48,6 +48,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 static constexpr index_t nDim = remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}); // Dirty HACK FELIX, TODO fix static constexpr index_t nSrc = remove_cvref_t::Size(); static constexpr index_t nDst = remove_cvref_t::Size(); @@ -101,7 +102,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(ThreadGroup::GetThreadId())); + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 5af29974f5..4a089c1ee6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -9,7 +9,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -1109,12 +1109,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle { ignore = b_element_op; const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(), // problem.MBlock, problem.NBlock, MPerBlock, NPerBlock); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = @@ -1125,19 +1125,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); - // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); - constexpr auto AKThreads = AK0Threads * AK1Threads; - constexpr auto AMRepeats = MPerBlock / AMThreads; - // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; - static_for<0, AMRepeats, 1>{}([&](auto m0) { - gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K; - // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); - }); const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1153,10 +1140,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // if(threadIdx.x==0) // printf("tid %d eid %d expert_stride %d bufsize %d\n", // threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1166,7 +1149,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1_mod8( a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + ck::tensor_operation::element_wise::PassThrough{}); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -1406,10 +1387,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const auto e_grid_desc_mblock_mperblock_nblock_nperblock = c_grid_desc_mblock_mperblock_nblock_nperblock; - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = MPerBlock / EMThreads; + static_assert(EMRepeats == 1, "only support 1 line per thread now!"); + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / EMThreads * EMRepeats; + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, EMRepeats, 1>{}([&](auto m0) { + scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N; + // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); + }); + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1423,7 +1414,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferCluster, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, @@ -1439,9 +1430,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle {c_ds_desc_refs, idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - + // if(threadIdx.x== 0) + // printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid + scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0)); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp new file mode 100644 index 0000000000..a5c824fcc9 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm2 : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + sorted_tile_size_{sorted_tile_size}, + a_m_k_{a_m_k}, + b_e_n_k_{b_e_n_k}, + c_t_n_{c_t_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& expert_ids_; + const Tensor& sorted_token_ids_; + const Tensor& a_m_k_; + const Tensor& b_e_n_k_; + Tensor& c_t_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t sorted_tile_size_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm2::Argument; + + float Run(const Argument& arg) + { + arg.c_t_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m); + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.a_m_k_.mDesc.GetLengths()[0]; + if(t < token_cnt) { + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_e_n_k_(e, n, k)); + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + } + + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + + arg.c_t_n_(t, n) += v_c; + }; + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_t_n_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck From 7ba5bff4c2736bf86837156ee99298b9fe17fc90 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sat, 8 Feb 2025 12:31:25 +0000 Subject: [PATCH 2/5] one tile ok --- .../65_gemm_multiply_multiply/moe_gemm2.cpp | 20 ++++++++++--------- ...hread_group_tensor_slice_transfer_v7r3.hpp | 13 ++++++------ ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 10 ++++------ .../cpu/reference_moe_gemm2.hpp | 18 ++++++++--------- script/cmake-ck-dev.sh | 2 +- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index 008e2681bb..5029f2ea95 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -119,6 +119,7 @@ using CDEElementOp = MultiplyMultiply; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -142,7 +143,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // ak1, bk1 AK1, BK1, // mn_perxdl - 32, 32, + MNPerXDL, MNPerXDL, // mn_xdlperwave MXDLPerWave, 1, // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra @@ -173,11 +174,11 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 128; ck::index_t K = 1024; - ck::index_t experts = 1; - ck::index_t sorted_tile_num = 1; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 2; ck::index_t sorted_tile_size = MPerBlock; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; - ck::index_t tokens = 1; + ck::index_t tokens = 32; if(argc == 1) { @@ -251,7 +252,7 @@ int main(int argc, char* argv[]) Tensor d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); - + e_t_n_device_result.SetZero(); std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl; @@ -358,8 +359,7 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); - e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - + // e_device_buf.FromDevice(e_t_n_device_result.mData.data()); Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2{}); + const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); const auto dst_thread_slice_origins = generate_tuple( - [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; }, Number{}); threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 4a089c1ee6..284e8dec01 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1115,8 +1115,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(), - // problem.MBlock, problem.NBlock, MPerBlock, NPerBlock); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -1393,14 +1391,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); constexpr auto EMRepeats = MPerBlock / EMThreads; + constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); static_assert(EMRepeats == 1, "only support 1 line per thread now!"); - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / EMThreads * EMRepeats; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats; StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, EMRepeats, 1>{}([&](auto m0) { scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N; - // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); + printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); }); + // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1433,7 +1433,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; // if(threadIdx.x== 0) - // printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid + scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0)); // space filling curve for threadwise C in VGPR @@ -1461,7 +1460,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - // printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n"); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS block_sync_lds(); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index a5c824fcc9..8633f061b1 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -71,13 +71,13 @@ struct ReferenceMoeGemm2 : public device::BaseOperator arg.c_t_n_.SetZero(); auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - AccDataType v_acc{0}; ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; const int t = arg.sorted_token_ids_(m); const int e = arg.expert_ids_(m / arg.sorted_tile_size_); - const int token_cnt = arg.a_m_k_.mDesc.GetLengths()[0]; + const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; + if(t < token_cnt) { for(int k = 0; k < K; ++k) { @@ -105,17 +105,17 @@ struct ReferenceMoeGemm2 : public device::BaseOperator v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); } + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + + arg.c_t_n_(t, n) += v_c; } - CDataType v_c{0}; - - arg.c_element_op_(v_c, v_acc); - - arg.c_t_n_(t, n) += v_c; }; make_ParallelTensorFunctor( - f_mk_kn_mn, arg.c_t_n_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( - std::thread::hardware_concurrency()); + f_mk_kn_mn, arg.a_m_k_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( + 1); return 0; } diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 7555a582ea..12c739e324 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O1 -g --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ From 123014559088e77f39c10410455bd78b72c7ce8e Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sun, 9 Feb 2025 09:02:32 +0000 Subject: [PATCH 3/5] gemm2 result ok --- example/65_gemm_multiply_multiply/moe_gemm2.cpp | 14 ++++++++------ ...emm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp | 4 ++-- ...e_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 4 ++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index 5029f2ea95..edaf763241 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -33,8 +33,8 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using A0DataType = F16; -using B0DataType = F16; +using A0DataType = F8; +using B0DataType = F8; using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; @@ -172,10 +172,10 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: // GEMM shape - ck::index_t N = 128; - ck::index_t K = 1024; + ck::index_t N = 6144; + ck::index_t K = 8192; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 2; + ck::index_t sorted_tile_num = 8; ck::index_t sorted_tile_size = MPerBlock; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t tokens = 32; @@ -341,6 +341,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } if (time_kernel) { + // not result correct here because output buf not setzero float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; @@ -357,9 +358,10 @@ int main(int argc, char* argv[]) if(do_verification) { + //gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); - // e_device_buf.FromDevice(e_t_n_device_result.mData.data()); Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; Run(kernel); @@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< GridwiseGemm, true, - InMemoryDataOperationEnum::Set, + InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>; Run(kernel); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 284e8dec01..911ae49f33 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1170,7 +1170,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), + make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1397,7 +1397,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, EMRepeats, 1>{}([&](auto m0) { scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N; - printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); + // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); }); // printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); From e21f36fc242b5de5353a299b557ae1f06bccbc41 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sun, 9 Feb 2025 13:44:42 +0000 Subject: [PATCH 4/5] moegemm2 ok --- .../65_gemm_multiply_multiply/moe_gemm2.cpp | 3 +- ...hread_group_tensor_slice_transfer_v7r3.hpp | 13 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 7 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 696 ++++++++++++++++++ 4 files changed, 711 insertions(+), 8 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index edaf763241..654542423c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -122,6 +122,7 @@ static constexpr ck::index_t MPerBlock = 32; static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint +static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); @@ -154,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 1, 1, S<1, 32, 1, 8>, S, + CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp index fb72a98b4d..1bd25994a6 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -7,7 +7,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp" #include "ck/utility/is_detected.hpp" namespace ck { @@ -42,6 +42,7 @@ template struct ThreadGroupTensorSliceTransfer_v7r3 { @@ -55,18 +56,21 @@ struct ThreadGroupTensorSliceTransfer_v7r3 using Index = MultiIndex; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + static constexpr index_t scatter_num = thread_slice_lengths.At(Number{}); __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( const SrcDescs& src_descs, const StaticallyIndexedArray& src_block_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_block_slice_origins, - const ElementwiseOperation& element_op) + const ElementwiseOperation& element_op, + const StaticallyIndexedArray &scatter_offsets) : threadwise_transfer_(src_descs, StaticallyIndexedArray{}, dst_descs, StaticallyIndexedArray{}, - element_op) + element_op, + scatter_offsets) { static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && @@ -197,7 +201,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v7r3; ThreadwiseTransfer threadwise_transfer_; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 911ae49f33..b78433bb30 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1392,7 +1392,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); constexpr auto EMRepeats = MPerBlock / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_assert(EMRepeats == 1, "only support 1 line per thread now!"); + // static_assert(EMRepeats == 1, "only support 1 line per thread now!"); const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats; StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, EMRepeats, 1>{}([&](auto m0) { @@ -1431,10 +1431,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + c_element_op, + scatter_offsets}; // if(threadIdx.x== 0) auto c_grid_buf = make_dynamic_buffer( - p_c_grid + scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0)); + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp new file mode 100644 index 0000000000..480fbc5ff0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -0,0 +1,696 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t ScatterDim = 1, + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3_scatter +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op, + const StaticallyIndexedArray &scatter_offsets) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op), + scatter_offsets_(scatter_offsets) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + // printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], dst_coords_(i).GetOffset()); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); + const auto scatter_offset = scatter_offsets_(Number{}); + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + // if(threadIdx.x==0) + // printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), scatter_offset ); + dst_bufs(i).template Update( + dst_offset, + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + // if(1) { + // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { + // using DstData = remove_cvref_t>; + // using print_vec_t = typename vector_type::type; + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, + // type_convert(dst_vectors[i].template AsType()[idx])); + // }); + // } + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + auto forward_step_scatter = [&]() constexpr + { + Index step_; + + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = i.value != ScatterDim ? forward_step[i] : 0; + + // if(threadIdx.x==0) + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + }); + + return step_; + } + (); + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step_scatter)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + auto reset_step_scatter = [&]() constexpr + { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = i.value != ScatterDim ? reset_step[Number{}] : 0; + + // if(threadIdx.x==0) + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + }); + + return step_; + } + (); + return reset_step_scatter; + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + StaticallyIndexedArray scatter_offsets_; + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck From 2e53f9725b4aed93909eada64f0066cb7cc7275f Mon Sep 17 00:00:00 2001 From: coderfeli Date: Mon, 10 Feb 2025 01:26:08 +0000 Subject: [PATCH 5/5] skip empty expert --- .../gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index b78433bb30..b51231c0eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1126,7 +1126,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); - + + const index_t t0 = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + if(t0 >= problem.NumTokens) + return; // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);