From 01020d0db46ee12678fa38e78308ea2f65b38572 Mon Sep 17 00:00:00 2001
From: rocking5566
Date: Fri, 11 Feb 2022 14:48:41 +0800
Subject: [PATCH] Support alpha beta scaling for GEMM (#78)

* [What] Add 2d version of bias, prepare to implement alpha / beta scaling

* Add alpha / beta functor

* Refine parameter of example

* [What] Use real type instead of template
  [Why] Prevent implicit cast

* Rename parameter for general operator

* Remove redundant comment

* Fix compile error

Co-authored-by: rocking
Co-authored-by: Chao Liu

[ROCm/composable_kernel commit: 6f928a08765e2110caf8ef20586c29d5e414ff71]
---
 .../element_wise_operation.hpp                |  35 ++
 device_operation/include/device_gemm.hpp      |  29 +
 .../device_gemm_xdl_c_shuffle_bias_2d.hpp     | 509 ++++++++++++++++++
 example/8_gemm_xdl_alpha_beta/README.md       |  59 ++
 .../gemm_xdl_alpha_beta.cpp                   | 272 ++++++++++
 example/CMakeLists.txt                        |   3 +
 6 files changed, 907 insertions(+)
 create mode 100644 device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp
 create mode 100644 example/8_gemm_xdl_alpha_beta/README.md
 create mode 100644 example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp

diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp
index d2054b8301..c2fe6a9f46 100644
--- a/composable_kernel/include/tensor_operation/element_wise_operation.hpp
+++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp
@@ -12,6 +12,41 @@ struct PassThrough
     __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }
 };
 
+struct Add
+{
+    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
+    {
+        y = x0 + x1;
+    }
+
+    __host__ __device__ constexpr void
+    operator()(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        // FIXME - Use float (acc type) bias in the future.
+        y = x0 + x1;
+    }
+};
+
+struct AlphaBetaAdd
+{
+    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
+
+    __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
+    {
+        y = alpha_ * x0 + beta_ * x1;
+    }
+
+    __host__ __device__ constexpr void
+    operator()(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        // FIXME - Let x0 be acc type
+        y = static_cast<half_t>(alpha_ * static_cast<float>(x0) + beta_ * static_cast<float>(x1));
+    }
+
+    float alpha_;
+    float beta_;
+};
+
 struct AddRelu
 {
     __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const
diff --git a/device_operation/include/device_gemm.hpp b/device_operation/include/device_gemm.hpp
index 5b386bd908..72b79e8531 100644
--- a/device_operation/include/device_gemm.hpp
+++ b/device_operation/include/device_gemm.hpp
@@ -8,6 +8,35 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGemmBias : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_bias,
+                        void* p_c,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGemmBiasPtr = std::unique_ptr<
+    DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
 template <
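For reference, the new `AlphaBetaAdd` functor is the epilogue of a `C = alpha * (A * B) + beta * C0` GEMM: the gridwise kernel hands it the accumulated product as `x0` and the 2d bias as `x1`. A minimal host-side sketch of that semantics, not part of the patch, assuming plain `float` buffers and row-major layouts for brevity (the example below uses `half_t` data and a column-major B):

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of the AlphaBetaAdd epilogue: x0 is the GEMM result
// A * B, x1 is the 2d bias C0, and the functor writes y = alpha*x0 + beta*x1.
void gemm_alpha_beta_ref(const std::vector<float>& a,  // M x K, row-major
                         const std::vector<float>& b,  // K x N, row-major
                         const std::vector<float>& c0, // M x N, row-major bias
                         std::vector<float>& c,        // M x N, row-major output
                         std::size_t M, std::size_t N, std::size_t K,
                         float alpha, float beta)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];

            c[m * N + n] = alpha * acc + beta * c0[m * N + n]; // AlphaBetaAdd
        }
}
```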
diff --git a/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp
new file mode 100644
index 0000000000..6ee7967382
--- /dev/null
+++ b/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp
@@ -0,0 +1,509 @@
+#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
+#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
+
+#include <iostream>
+#include <sstream>
+#include "device.hpp"
+#include "device_base.hpp"
+#include "device_gemm.hpp"
+#include "device_gemm_xdl.hpp"
+#include "common_header.hpp"
+#include "tensor_layout.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "gridwise_gemm_xdlops_v3r2.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <
+    typename ADataType,
+    typename BDataType,
+    typename CDataType,
+    typename AccDataType,
+    typename ALayout,
+    typename BLayout,
+    typename CLayout,
+    typename AElementwiseOperation,
+    typename BElementwiseOperation,
+    typename CElementwiseOperation,
+    ck::index_t BlockSize,
+    ck::index_t MPerBlock,
+    ck::index_t NPerBlock,
+    ck::index_t K0PerBlock,
+    ck::index_t K1,
+    ck::index_t MPerXDL,
+    ck::index_t NPerXDL,
+    ck::index_t MXdlPerWave,
+    ck::index_t NXdlPerWave,
+    typename ABlockTransferThreadClusterLengths_K0_M_K1,
+    typename ABlockTransferThreadClusterArrangeOrder,
+    typename ABlockTransferSrcAccessOrder,
+    ck::index_t ABlockTransferSrcVectorDim,
+    ck::index_t ABlockTransferSrcScalarPerVector,
+    ck::index_t ABlockTransferDstScalarPerVector_K1,
+    bool ABlockLdsAddExtraM,
+    typename BBlockTransferThreadClusterLengths_K0_N_K1,
+    typename BBlockTransferThreadClusterArrangeOrder,
+    typename BBlockTransferSrcAccessOrder,
+    ck::index_t BBlockTransferSrcVectorDim,
+    ck::index_t BBlockTransferSrcScalarPerVector,
+    ck::index_t BBlockTransferDstScalarPerVector_K1,
+    bool BBlockLdsAddExtraN,
+    index_t CShuffleMXdlPerWavePerShuffle,
+    index_t CShuffleNXdlPerWavePerShuffle,
+    typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
+    index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
+struct DeviceGemmXdl_C_Shuffle_Bias_2d
+    : public DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+
+    static constexpr auto K1Number = Number<K1>{};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
+    {
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        const auto a_grid_desc_m_k = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_k0_m_k1 =
+            transform_tensor_descriptor(a_grid_desc_m_k,
+                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                                                   make_pass_through_transform(M)),
+                                        make_tuple(Sequence<1>{}, Sequence<0>{}),
+                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+        return a_grid_desc_k0_m_k1;
+    }
+
+    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    {
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        const auto b_grid_desc_k_n = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+            }
+        }();
+
+        const auto b_grid_desc_k0_n_k1 =
+            transform_tensor_descriptor(b_grid_desc_k_n,
+                                        make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                                                   make_pass_through_transform(N)),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+        return b_grid_desc_k0_n_k1;
+    }
+
+    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
+    {
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+        {
+            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+        }
+        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+        {
+            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+        }
+    }
+
+    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using C0GridDesc_M_N    = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
+    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
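+
+    // Note: the 2d bias C0 reuses the C descriptor maker, i.e. it is an M x N
+    // tensor with the same layout (and, in the example, the same StrideC) as C,
+    // so the bias is read at the same per-element coordinates the C shuffle writes.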
+
+    // GridwiseGemm
+    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2<
+        BlockSize,
+        ADataType, // TODO: distinguish A/B datatype
+        AccDataType,
+        CDataType,
+        InMemoryDataOperationEnum_t::Set,
+        AGridDesc_K0_M_K1,
+        BGridDesc_K0_N_K1,
+        CGridDesc_M_N,
+        C0GridDesc_M_N,
+        AElementwiseOperation,
+        BElementwiseOperation,
+        CElementwiseOperation,
+        MPerBlock,
+        NPerBlock,
+        K0PerBlock,
+        MPerXDL,
+        NPerXDL,
+        K1,
+        MXdlPerWave,
+        NXdlPerWave,
+        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_K1,
+        false,
+        ABlockLdsAddExtraM,
+        BBlockTransferThreadClusterLengths_K0_N_K1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_K1,
+        false,
+        BBlockLdsAddExtraN,
+        CShuffleMXdlPerWavePerShuffle,
+        CShuffleNXdlPerWavePerShuffle,
+        CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
+        CBlockTransferScalarPerVector_NWaveNPerXdl>;
+
+    // Argument
+    struct Argument : public BaseArgument
+    {
+        Argument(const ADataType* p_a_grid,
+                 const BDataType* p_b_grid,
+                 const CDataType* p_bias_grid,
+                 CDataType* p_c_grid,
+                 index_t M,
+                 index_t N,
+                 index_t K,
+                 index_t StrideA,
+                 index_t StrideB,
+                 index_t StrideC,
+                 index_t M01,
+                 index_t N01,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CElementwiseOperation c_element_op)
+            : p_a_grid_{p_a_grid},
+              p_b_grid_{p_b_grid},
+              p_c0_grid_{p_bias_grid},
+              p_c_grid_{p_c_grid},
+              a_grid_desc_k0_m_k1_{},
+              b_grid_desc_k0_n_k1_{},
+              c0_grid_desc_m_n_{},
+              c_grid_desc_m_n_{},
+              c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
+              c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
+              block_2_ctile_map_{},
+              M01_{M01},
+              N01_{N01},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              c_element_op_{c_element_op}
+        {
+            a_grid_desc_k0_m_k1_ =
+                DeviceGemmXdl_C_Shuffle_Bias_2d::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+            b_grid_desc_k0_n_k1_ =
+                DeviceGemmXdl_C_Shuffle_Bias_2d::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+            c0_grid_desc_m_n_ =
+                DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC);
+            c_grid_desc_m_n_ =
+                DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC);
+
+            if(GridwiseGemm::CheckValidity(
+                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
+            {
+                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
+                    GridwiseGemm::
+                        MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+                            c0_grid_desc_m_n_);
+
+                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
+                    GridwiseGemm::
+                        MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
+                            c_grid_desc_m_n_);
+
+                block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
+            }
+        }
+
+        // private:
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        const CDataType* p_c0_grid_;
+        CDataType* p_c_grid_;
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        C0GridDesc_M_N c0_grid_desc_m_n_;
+        CGridDesc_M_N c_grid_desc_m_n_;
+        typename GridwiseGemm::
+            C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+                c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
+        typename GridwiseGemm::
+            CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+                c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
+        typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
+        index_t M01_;
+        index_t N01_;
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument;
+
+        float Run(const Argument& arg, int nrepeat = 1)
+        {
+            {
+                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
+                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
+                          << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+
+                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+            }
<< "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return 
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(const ADataType* p_a,
+                             const BDataType* p_b,
+                             const CDataType* p_bias,
+                             CDataType* p_c,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             index_t StrideA,
+                             index_t StrideB,
+                             index_t StrideC,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op)
+    {
+        return Argument{p_a,
+                        p_b,
+                        p_bias,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                      const void* p_b,
+                                                      const void* p_bias,
+                                                      void* p_c,
+                                                      index_t M,
+                                                      index_t N,
+                                                      index_t K,
+                                                      index_t StrideA,
+                                                      index_t StrideB,
+                                                      index_t StrideC,
+                                                      AElementwiseOperation a_element_op,
+                                                      BElementwiseOperation b_element_op,
+                                                      CElementwiseOperation c_element_op) override
+    {
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
+                                          static_cast<const BDataType*>(p_b),
+                                          static_cast<const CDataType*>(p_bias),
+                                          static_cast<CDataType*>(p_c),
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          StrideC,
+                                          1,
+                                          1,
+                                          a_element_op,
+                                          b_element_op,
+                                          c_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceGemmXdl_C_Shuffle_Bias_2d"
+            << "<"
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << K0PerBlock
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
+#endif
diff --git a/example/8_gemm_xdl_alpha_beta/README.md b/example/8_gemm_xdl_alpha_beta/README.md
new file mode 100644
index 0000000000..a3dc4a75fc
--- /dev/null
+++ b/example/8_gemm_xdl_alpha_beta/README.md
@@ -0,0 +1,59 @@
+# Instructions for ```gemm_xdl_alpha_beta``` Example
+
+## Docker script
+```bash
+docker run \
+-it \
+--rm \
+--privileged \
+--group-add sudo \
+-w /root/workspace \
+-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
+rocm/tensorflow:rocm4.3.1-tf2.6-dev \
+/bin/bash
+```
+
+## Build ```gemm_xdl_alpha_beta```
+```bash
+mkdir build && cd build
+```
+
+```bash
+# Need to specify target ID, example below is gfx908
+cmake \
+-D BUILD_DEV=OFF \
+-D CMAKE_BUILD_TYPE=Release \
+-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH=/opt/rocm \
+..
+```
+
+```bash
+ make -j gemm_xdl_alpha_beta
+```
+
+## Run ```gemm_xdl_alpha_beta```
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: run kernel # of times (>1)
+#arg4/arg5 (optional): alpha, beta
+./example/gemm_xdl_alpha_beta 1 1 1 0.5 0.5
+```
+Result (MI100 @ 1502MHz, 184.6 TFlops peak FP16)
+```
+a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
+b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
+c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
+arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
+arg.c0_grid_desc_m_n_{ 3840, 4096}
+arg.c_grid_desc_m_n_{ 3840, 4096}
+launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
+Warm up
+Start running 1 times...
+Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s
+error: 0
+max_diff: 0, 558.5, 558.5
+```
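The example below drives the concrete `DeviceGemmXdl_C_Shuffle_Bias_2d` instance directly; the `DeviceGemmBias` base added in `device_gemm.hpp` also allows the same call through the type-erased pointer API. A sketch of that path, not part of the patch: it assumes the `DeviceGemmInstance`, `AElementOp`, `BElementOp` and `CElementOp` aliases from `gemm_xdl_alpha_beta.cpp` below, plus already-populated device buffers.

```cpp
#include <memory>

// Hypothetical helper: run the GEMM through the DeviceGemmBias interface.
// p_a/p_b/p_bias/p_c must be valid device buffers of the right sizes.
float run_via_interface(const void* p_a, const void* p_b, const void* p_bias, void* p_c,
                        ck::index_t M, ck::index_t N, ck::index_t K,
                        ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                        float alpha, float beta)
{
    // Type-erased handle to the concrete tuned instance.
    ck::tensor_operation::device::DeviceGemmBiasPtr<AElementOp, BElementOp, CElementOp> gemm =
        std::make_unique<DeviceGemmInstance>();

    auto argument = gemm->MakeArgumentPointer(p_a, p_b, p_bias, p_c,
                                              M, N, K, StrideA, StrideB, StrideC,
                                              AElementOp{}, BElementOp{}, CElementOp{alpha, beta});

    if(!gemm->IsSupportedArgument(argument.get()))
        return -1.f; // this tuned instance cannot handle the problem size

    auto invoker = gemm->MakeInvokerPointer();
    return invoker->Run(argument.get()); // average kernel time in ms
}
```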
diff --git a/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp
new file mode 100644
index 0000000000..2a7b6991e2
--- /dev/null
+++ b/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp
@@ -0,0 +1,272 @@
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+#include <stdlib.h>
+#include <half.hpp>
+#include "config.hpp"
+#include "print.hpp"
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "host_gemm.hpp"
+#include "device_tensor.hpp"
+#include "device_base.hpp"
+#include "device_gemm_xdl_c_shuffle_bias_2d.hpp"
+#include "element_wise_operation.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using ADataType   = ck::half_t;
+using BDataType   = ck::half_t;
+using CDataType   = ck::half_t;
+using AccDataType = float;
+
+using ALayout = ck::tensor_layout::gemm::RowMajor;
+using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd;
+
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_2d<
+    ADataType,            // ADataType
+    BDataType,            // BDataType
+    CDataType,            // CDataType
+    AccDataType,          // AccDataType
+    ALayout,              // ALayout
+    BLayout,              // BLayout
+    CLayout,              // CLayout
+    AElementOp,           // AElementwiseOperation
+    BElementOp,           // BElementwiseOperation
+    CElementOp,           // CElementwiseOperation
+    256,                  // BlockSize
+    256,                  // MPerBlock
+    128,                  // NPerBlock
+    4,                    // K0PerBlock
+    8,                    // K1
+    32,                   // MPerXDL
+    32,                   // NPerXDL
+    4,                    // MXdlPerWave
+    2,                    // NXdlPerWave
+    S<4, 64, 1>,          // ABlockTransferThreadClusterLengths_K0_M_K1
+    S<1, 0, 2>,           // ABlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,           // ABlockTransferSrcAccessOrder
+    2,                    // ABlockTransferSrcVectorDim
+    8,                    // ABlockTransferSrcScalarPerVector
+    8,                    // ABlockTransferDstScalarPerVector_K1
+    true,                 // ABlockLdsAddExtraM
+    S<4, 64, 1>,          // BBlockTransferThreadClusterLengths_K0_N_K1
+    S<1, 0, 2>,           // BBlockTransferThreadClusterArrangeOrder
+    S<1, 0, 2>,           // BBlockTransferSrcAccessOrder
+    2,                    // BBlockTransferSrcVectorDim
+    8,                    // BBlockTransferSrcScalarPerVector
+    8,                    // BBlockTransferDstScalarPerVector_K1
+    true,                 // BBlockLdsAddExtraN
+    1,                    // CShuffleMXdlPerWavePerShuffle
+    1,                    // CShuffleNXdlPerWavePerShuffle
+    S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+    8>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
+// clang-format on
+
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+static void host_verify(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        const Tensor<CType>& c0_m_n,
+                        Tensor<CType>& c_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CElementwiseOperation& c_element_op)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        AccDataType v = 0;
+        AccDataType a = 0;
+        AccDataType b = 0;
+        for(int k = 0; k < K; ++k)
+        {
+            a_element_op(a, a_m_k(m, k));
+            b_element_op(b, b_k_n(k, n));
+            v += a * b;
+        }
+
+        CType y = static_cast<CType>(v);
+
+        c_element_op(c_m_n(m, n), y, c0_m_n(m, n));
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}
+
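+// The epilogue applied by host_verify above matches the device path:
+//     c_m_n(m, n) = alpha * (sum_k a(m, k) * b(k, n)) + beta * c0_m_n(m, n)
+// with the dot product accumulated in AccDataType (float) and only the final
+// result cast down to CDataType (half_t).
+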
+int main(int argc, char* argv[])
+{
+    bool do_verification = 0;
+    int init_method      = 0;
+    int nrepeat          = 5;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = 4096;
+    ck::index_t StrideB = 4096;
+    ck::index_t StrideC = 4096;
+
+    float alpha = 1.0f;
+    float beta  = 1.0f;
+
+    if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+    }
+    else if(argc == 6)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+
+        alpha = std::stof(argv[4]);
+        beta  = std::stof(argv[5]);
+    }
+    else if(argc == 12)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        nrepeat         = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideC = std::stoi(argv[9]);
+
+        alpha = std::stof(argv[10]);
+        beta  = std::stof(argv[11]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 11: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC, alpha, beta\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({stride, 1}));
+            }
+            else
+            {
+                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                            std::vector<std::size_t>({1, stride}));
+            }
+        };
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl;
+    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        c0_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
+        break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        c0_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{-0.5, 0.5});
+    }
+
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+
+    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
+    c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
+
+    // do GEMM
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c0_m_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+                                      M,
+                                      N,
+                                      K,
+                                      StrideA,
+                                      StrideB,
+                                      StrideC,
+                                      AElementOp{},
+                                      BElementOp{},
+                                      CElementOp{alpha, beta});
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, nrepeat);
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        host_verify(a_m_k,
+                    b_k_n,
+                    c0_m_n,
+                    c_m_n_host_result,
+                    AElementOp{},
+                    BElementOp{},
+                    CElementOp{alpha, beta});
+
+        check_error(c_m_n_host_result, c_m_n_device_result);
+    }
+}
diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index f9474425bc..998e9b3578 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -19,6 +19,7 @@ set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp)
 set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp)
 set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp)
 set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp)
+set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
 
 add_executable(gemm_xdl ${GEMM_XDL_SOURCE})
 add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE})
@@ -27,6 +28,7 @@ add_executable(conv2d_fwd_xdl ${CONV2D_FWD_XDL_SOURCE})
 add_executable(conv2d_fwd_xdl_bias_relu ${CONV2D_FWD_XDL_BIAS_RELU_SOURCE})
 add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE})
 add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE})
+add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
 
 target_link_libraries(gemm_xdl PRIVATE host_tensor)
 target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor)
@@ -35,3 +37,4 @@ target_link_libraries(conv2d_fwd_xdl PRIVATE host_tensor)
 target_link_libraries(conv2d_fwd_xdl_bias_relu PRIVATE host_tensor)
 target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor)
 target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
+target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
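As a cross-check of the README numbers: the example counts `2 * M * N * K` flop and A + B + C traffic only (the C0 read is not included in `num_btype` above). A small standalone sketch reproducing the reported 137.5 TFlops / 103 GB/s from the measured 0.936965 ms:

```cpp
#include <cstdio>

// Reproduces the example's performance bookkeeping for the README run
// (M = 3840, N = 4096, K = 4096, fp16 in/out, 0.936965 ms average time).
int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const double ave_ms = 0.936965;

    const double flop      = 2.0 * M * N * K;             // ~1.288e11 flop
    const double num_bytes = 2 * (M * K + K * N + M * N); // fp16 A, B, C only

    std::printf("%.3f TFlops\n", flop / 1.e9 / ave_ms);   // ~137.5
    std::printf("%.3f GB/s\n", num_bytes / 1.e6 / ave_ms); // ~103.0
    return 0;
}
```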