diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index e4e10d5b76..8b03a808ec 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -5,4 +5,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_m # target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm_fp16 moe_gemm_fp16.cpp) +add_example_executable(example_moe_gemm1 moe_gemm1.cpp) +add_example_executable(example_moe_gemm2 moe_gemm2.cpp) diff --git a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp similarity index 100% rename from example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp rename to example/65_gemm_multiply_multiply/moe_gemm1.cpp diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp new file mode 100644 index 0000000000..654542423c --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +// using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(EDataType& e, + const float& c, + const float& d0, + const float& d1) const + { + // const float x0_f = c * d0 * d1; + const float x0_f = c; + // printf("epi %f\n", c); + e = ck::type_convert(x0_f); + } + + // template <> + // __host__ __device__ constexpr void operator()(BF16& e, + // const float& c, + // const float& d0, + // const float& d1) const + // { + // const float x0_f = c; + // // const float x0_f = c * d0 * d1; + + // e = ck::type_convert(x0_f); + // } +}; + + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); +static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint +static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + 256, MPerBlock, 128, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, 1, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType>; + // kernel 2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + +// tokens = 1 +// topk = 1 +// experts = 8 +// per expert: + // GEMM shape + ck::index_t N = 6144; + ck::index_t K = 8192; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 8; + ck::index_t sorted_tile_size = MPerBlock; + ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; + ck::index_t tokens = 32; + + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 5: N, K\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + // const ck::index_t experts = 8; + Tensor expert_ids(HostTensorDescriptor({experts}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({SORTED_SIZE}, {1})); + for (int i = 0; i < sorted_tile_num; i++) { + expert_ids.mData[i] = i; + } + int token_per_tile = tokens / sorted_tile_num; + int tokenid = 0; + // sorted_token_ids.mData[0] = 0; + for (int i = 0; i < SORTED_SIZE; i++) { + int tile_off = i % sorted_tile_size; + if(tile_off < token_per_tile) + sorted_token_ids.mData[i] = tokenid++; + else + sorted_token_ids.mData[i] = tokens; + } + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); + // Tensor b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{})); + // Tensor b0_preshuffled( + // f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D0Layout{})); + Tensor d1_t_n(f_host_tensor_descriptor(tokens, N, StrideD, D1Layout{})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d1_t_n: " << d1_t_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_t_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_t_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + a0_m_k.savetxt("a.txt"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_t_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + SORTED_SIZE, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if (time_kernel) { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * SORTED_SIZE * N * K; + std::size_t num_btype = + sizeof(A0DataType) * SORTED_SIZE * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * SORTED_SIZE * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + } + + if(do_verification) + { + //gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0 ,0,1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument( + sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int t = 0; t < tokens; ++t) + { + + // const int t = sorted_token_ids(m); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + e_t_n_device_result.savetxt("out.txt"); + e_t_n_host_result.savetxt("ref.txt"); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp index 46d0c6ac2e..1bd25994a6 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -7,7 +7,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp" #include "ck/utility/is_detected.hpp" namespace ck { @@ -42,30 +42,35 @@ template struct ThreadGroupTensorSliceTransfer_v7r3 { static constexpr index_t nDim = remove_cvref_t>::GetNumOfDimension(); + static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}); // Dirty HACK FELIX, TODO fix static constexpr index_t nSrc = remove_cvref_t::Size(); static constexpr index_t nDst = remove_cvref_t::Size(); using Index = MultiIndex; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + static constexpr index_t scatter_num = thread_slice_lengths.At(Number{}); __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( const SrcDescs& src_descs, const StaticallyIndexedArray& src_block_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_block_slice_origins, - const ElementwiseOperation& element_op) + const ElementwiseOperation& element_op, + const StaticallyIndexedArray &scatter_offsets) : threadwise_transfer_(src_descs, StaticallyIndexedArray{}, dst_descs, StaticallyIndexedArray{}, - element_op) + element_op, + scatter_offsets) { static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && @@ -100,17 +105,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( make_multi_index(ThreadGroup::GetThreadId())); - - const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; - const auto src_thread_slice_origins = generate_tuple( - [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; }, Number{}); + const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); const auto dst_thread_slice_origins = generate_tuple( - [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; }, Number{}); threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); @@ -197,7 +201,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v7r3; ThreadwiseTransfer threadwise_transfer_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index 958265e712..d10c334537 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -279,7 +279,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< GridwiseGemm, true, - InMemoryDataOperationEnum::Set, + InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Odd>; Run(kernel); @@ -289,7 +289,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< GridwiseGemm, true, - InMemoryDataOperationEnum::Set, + InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>; Run(kernel); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp new file mode 100644 index 0000000000..480fbc5ff0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -0,0 +1,696 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t ScatterDim = 1, + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3_scatter +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op, + const StaticallyIndexedArray &scatter_offsets) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op), + scatter_offsets_(scatter_offsets) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + // printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], dst_coords_(i).GetOffset()); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); + const auto scatter_offset = scatter_offsets_(Number{}); + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + // if(threadIdx.x==0) + // printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), scatter_offset ); + dst_bufs(i).template Update( + dst_offset, + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + // if(1) { + // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { + // using DstData = remove_cvref_t>; + // using print_vec_t = typename vector_type::type; + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, + // type_convert(dst_vectors[i].template AsType()[idx])); + // }); + // } + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + auto forward_step_scatter = [&]() constexpr + { + Index step_; + + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = i.value != ScatterDim ? forward_step[i] : 0; + + // if(threadIdx.x==0) + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + }); + + return step_; + } + (); + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step_scatter)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + auto reset_step_scatter = [&]() constexpr + { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = i.value != ScatterDim ? reset_step[Number{}] : 0; + + // if(threadIdx.x==0) + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + }); + + return step_; + } + (); + return reset_step_scatter; + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + StaticallyIndexedArray scatter_offsets_; + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp new file mode 100644 index 0000000000..8633f061b1 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm2 : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + sorted_tile_size_{sorted_tile_size}, + a_m_k_{a_m_k}, + b_e_n_k_{b_e_n_k}, + c_t_n_{c_t_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& expert_ids_; + const Tensor& sorted_token_ids_; + const Tensor& a_m_k_; + const Tensor& b_e_n_k_; + Tensor& c_t_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t sorted_tile_size_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm2::Argument; + + float Run(const Argument& arg) + { + arg.c_t_n_.SetZero(); + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + AccDataType v_acc{0}; + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m); + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.c_t_n_.mDesc.GetLengths()[0]; + + if(t < token_cnt) { + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_e_n_k_(e, n, k)); + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, n, k)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + + arg.c_t_n_(t, n) += v_c; + } + + }; + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.a_m_k_.mDesc.GetLengths()[0], arg.c_t_n_.mDesc.GetLengths()[1])( + 1); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const index_t sorted_tile_size, + const Tensor& a_m_k, + const Tensor& b_e_n_k, + Tensor& c_t_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMoeGemm2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 7555a582ea..12c739e324 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O1 -g --save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \