From 8a5bb9f34bbaaf89b087f00b3a69b603c00a3a95 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sat, 8 Feb 2025 09:52:10 +0000 Subject: [PATCH] add files , build and run ok --- .../65_gemm_multiply_multiply/CMakeLists.txt | 3 +- .../{moe_gemm_fp16.cpp => moe_gemm1.cpp} | 0 .../65_gemm_multiply_multiply/moe_gemm2.cpp | 399 ++++++++++++++++++ ...hread_group_tensor_slice_transfer_v7r3.hpp | 3 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 50 +-- .../cpu/reference_moe_gemm2.hpp | 173 ++++++++ 6 files changed, 598 insertions(+), 30 deletions(-) rename example/65_gemm_multiply_multiply/{moe_gemm_fp16.cpp => moe_gemm1.cpp} (100%) create mode 100644 example/65_gemm_multiply_multiply/moe_gemm2.cpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index e4e10d5b76..8b03a808ec 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -5,4 +5,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_m # target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm_fp16 moe_gemm_fp16.cpp) +add_example_executable(example_moe_gemm1 moe_gemm1.cpp) +add_example_executable(example_moe_gemm2 moe_gemm2.cpp) diff --git a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp similarity index 100% rename from example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp rename to example/65_gemm_multiply_multiply/moe_gemm1.cpp diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp new file mode 100644 index 0000000000..008e2681bb --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +// using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(EDataType& e, + const float& c, + const float& d0, + const float& d1) const + { + // const float x0_f = c * d0 * d1; + const float x0_f = c; + // printf("epi %f\n", c); + e = ck::type_convert(x0_f); + } + + // template <> + // __host__ __device__ constexpr void operator()(BF16& e, + // const float& c, + // const float& d0, + // const float& d1) const + // { + // const float x0_f = c; + // // const float x0_f = c * d0 * d1; + + // e = ck::type_convert(x0_f); + // } +}; + + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); +static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + 256, MPerBlock, 128, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + 32, 32, + // mn_xdlperwave + MXDLPerWave, 1, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + 1, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>; + // kernel 2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + +// tokens = 1 +// topk = 1 +// experts = 8 +// per expert: + // GEMM shape + ck::index_t N = 128; + ck::index_t K = 1024; + ck::index_t experts = 1; + ck::index_t sorted_tile_num = 1; + ck::index_t sorted_tile_size = MPerBlock; + ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; + ck::index_t tokens = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 5: N, K\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + // const ck::index_t experts = 8; + Tensor expert_ids(HostTensorDescriptor({experts}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); + for (int i = 0; i < sorted_tile_num; i++) { + expert_ids.mData[i] = i; + } + int token_per_tile = tokens / sorted_tile_num; + int tokenid = 0; + // sorted_token_ids.mData[0] = 0; + for (int i = 0; i < SORTED_SIZE; i++) { + int tile_off = i % sorted_tile_size; + if(tile_off < token_per_tile) + sorted_token_ids.mData[i] = tokenid++; + else + sorted_token_ids.mData[i] = tokens; + } + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + // Tensor b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{})); + // Tensor b0_preshuffled( + // f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{})); + Tensor d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_t_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + a0_m_k.savetxt("a.txt"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_t_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + SORTED_SIZE, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if (time_kernel) { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; + std::size_t num_btype = + sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument( + sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < SORTED_SIZE; ++m) + { + + const int t = sorted_token_ids(m); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + e_t_n_device_result.savetxt("out.txt"); + e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp index 46d0c6ac2e..1feea921a6 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -48,6 +48,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 static constexpr index_t nDim = remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}); // Dirty HACK FELIX, TODO fix static constexpr index_t nSrc = remove_cvref_t::Size(); static constexpr index_t nDst = remove_cvref_t::Size(); @@ -101,7 +102,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(ThreadGroup::GetThreadId())); + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 5af29974f5..4a089c1ee6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -9,7 +9,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -1109,12 +1109,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle { ignore = b_element_op; const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + problem.NumTokens, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(), // problem.MBlock, problem.NBlock, MPerBlock, NPerBlock); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = @@ -1125,19 +1125,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); - // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); - constexpr auto AKThreads = AK0Threads * AK1Threads; - constexpr auto AMRepeats = MPerBlock / AMThreads; - // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; - static_for<0, AMRepeats, 1>{}([&](auto m0) { - gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K; - // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); - }); const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1153,10 +1140,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // if(threadIdx.x==0) // printf("tid %d eid %d expert_stride %d bufsize %d\n", // threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1166,7 +1149,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1_mod8( a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + ck::tensor_operation::element_wise::PassThrough{}); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -1406,10 +1387,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const auto e_grid_desc_mblock_mperblock_nblock_nperblock = c_grid_desc_mblock_mperblock_nblock_nperblock; - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = MPerBlock / EMThreads; + static_assert(EMRepeats == 1, "only support 1 line per thread now!"); + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / EMThreads * EMRepeats; + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, EMRepeats, 1>{}([&](auto m0) { + scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N; + // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); + }); + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1423,7 +1414,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferCluster, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, @@ -1439,9 +1430,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle {c_ds_desc_refs, idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - + // if(threadIdx.x== 0) + // printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid + scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() - scatter_offsets(I0)); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp new file mode 100644 index 0000000000..a5c824fcc9 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm2 : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + sorted_tile_size_{sorted_tile_size}, + a_m_k_{a_m_k}, + b_e_n_k_{b_e_n_k}, + c_t_n_{c_t_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& expert_ids_; + const Tensor& sorted_token_ids_; + const Tensor& a_m_k_; + const Tensor& b_e_n_k_; + Tensor& c_t_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t sorted_tile_size_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm2::Argument; + + float Run(const Argument& arg) + { + arg.c_t_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m); + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.a_m_k_.mDesc.GetLengths()[0]; + if(t < token_cnt) { + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_e_n_k_(e, n, k)); + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + } + + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + + arg.c_t_n_(t, n) += v_c; + }; + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_t_n_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck