From 7094d7c9108cbbd10c7fe2dddb930360dd7de887 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 30 Jun 2022 19:55:09 -0500 Subject: [PATCH] Gemm + bias + c_permute (#312) * init commit * add desc * finished c permute * fixed vector lens [ROCm/composable_kernel commit: fa9a0a5cfbc5daaad5650403725679971d79cb1e] --- example/25_gemm_bias_c_permute/CMakeLists.txt | 1 + .../gemm_bias_c_permute_xdl_fp16.cpp | 284 +++++++ example/CMakeLists.txt | 1 + .../gpu/device/device_gemm_bias_c_permute.hpp | 57 ++ .../device/device_gemm_bias_c_permute_xdl.hpp | 761 ++++++++++++++++++ .../element/binary_element_wise_operation.hpp | 11 +- 6 files changed, 1113 insertions(+), 2 deletions(-) create mode 100644 example/25_gemm_bias_c_permute/CMakeLists.txt create mode 100644 example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp diff --git a/example/25_gemm_bias_c_permute/CMakeLists.txt b/example/25_gemm_bias_c_permute/CMakeLists.txt new file mode 100644 index 0000000000..29b1d94b3c --- /dev/null +++ b/example/25_gemm_bias_c_permute/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp) diff --git a/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp new file mode 100644 index 0000000000..e7a439ca34 --- /dev/null +++ b/example/25_gemm_bias_c_permute/gemm_bias_c_permute_xdl_fp16.cpp @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl +//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t M0 = 4; + ck::index_t M1 = 32; + ck::index_t M2 = 128; + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + // GEMM shape + ck::index_t M = M0 * M1 * M2; + ck::index_t N = N0 * N1; + ck::index_t K = 128; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + +#if 1 + // E = [M0, N0, M1, N1, M2] + ck::index_t stride_E_M0 = N0 * M1 * N1 * M2; + ck::index_t stride_E_M1 = N1 * M2; + ck::index_t stride_E_M2 = 1; + ck::index_t stride_E_N0 = M1 * N1 * M2; + ck::index_t stride_E_N1 = M2; + + // D = [0, N0, 0, N1, 0] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; +#else + // D = [0, 0, 0, N0, N1] + ck::index_t stride_D_M0 = 0; + ck::index_t stride_D_M1 = 0; + ck::index_t stride_D_M2 = 0; + ck::index_t stride_D_N0 = N1; + ck::index_t stride_D_N1 = 1; + + // E = [M0, M1, M2, N0, N1] + ck::index_t stride_E_M0 = M1 * M2 * N0 * N1; + ck::index_t stride_E_M1 = M2 * N0 * N1; + ck::index_t stride_E_M2 = N0 * N1; + ck::index_t stride_E_N0 = N1; + ck::index_t stride_E_N1 = 1; +#endif + + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{ + M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1}; + const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{ + M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + auto f_host_de_tensor_descriptor = + [](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) { + std::size_t m0 = de_grid_desc.M0_; + std::size_t m1 = de_grid_desc.M1_; + std::size_t m2 = de_grid_desc.M2_; + std::size_t n0 = de_grid_desc.N0_; + std::size_t n1 = de_grid_desc.N1_; + std::size_t stride_m0 = de_grid_desc.stride_M0_; + std::size_t stride_m1 = de_grid_desc.stride_M1_; + std::size_t stride_m2 = de_grid_desc.stride_M2_; + std::size_t stride_n0 = de_grid_desc.stride_N0_; + std::size_t stride_n1 = de_grid_desc.stride_N1_; + return HostTensorDescriptor( + std::vector({m0, m1, m2, n0, n1}), + std::vector({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1})); + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + Tensor d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc)); + Tensor e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc)); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl; + std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) * + d_m0_m1_m2_n0_n1.mDesc.GetElementSpace()); + DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) * + e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(), + b_k_n_device_buf.GetDeviceBuffer(), + d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), + e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m0 = 0; m0 < M0; ++m0) + for(int m1 = 0; m1 < M1; ++m1) + for(int m2 = 0; m2 < M2; ++m2) + for(int n0 = 0; n0 < N0; ++n0) + for(int n1 = 0; n1 < N1; ++n1) + { + int m = m0 * M1 * M2 + m1 * M2 + m2; + int n = n0 * N1 + n1; + + cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1), + ck::type_convert(c_m_n(m, n)), + d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1)); + } + + e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data()); + + return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData, + e_m0_m1_m2_n0_n1_host_result.mData) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 9bba66ad0b..1c568c9d18 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -42,3 +42,4 @@ add_subdirectory(20_convnd_bwd_weight_xdl) add_subdirectory(21_gemm_layernorm) add_subdirectory(22_cgemm) add_subdirectory(23_softmax) +add_subdirectory(25_gemm_bias_c_permute) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp new file mode 100644 index 0000000000..bde0d48c15 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DEGridDesc_M0_M1_M2_N0_N1 +{ + ck::index_t M0_, M1_, M2_, N0_, N1_; + ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; +}; + +// input : A[M, K], B[K, N], +// input : D[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D) +template +struct DeviceGemmBiasCPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasCPermutePtr = std::unique_ptr< + DeviceGemmBiasCPermute>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp new file mode 100644 index 0000000000..f74cb0dc84 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/device_utility/device_prop.hpp" +#include "ck/device_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// input : A[M, K], or A[K, N] +// input : B[K, N], or A[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute +{ + using DeviceOp = DeviceGemmBiasCPermute_Xdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t NumDTensor = I1; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) + { + index_t M0 = d_e_grid_desc.M0_; + index_t M1 = d_e_grid_desc.M1_; + index_t M2 = d_e_grid_desc.M2_; + index_t N0 = d_e_grid_desc.N0_; + index_t N1 = d_e_grid_desc.N1_; + + index_t stride_M0 = d_e_grid_desc.stride_M0_; + index_t stride_M1 = d_e_grid_desc.stride_M1_; + index_t stride_M2 = d_e_grid_desc.stride_M2_; + index_t stride_N0 = d_e_grid_desc.stride_N0_; + index_t stride_N1 = d_e_grid_desc.stride_N1_; + + const auto MRaw = M0 * M1 * M2; + const auto NRaw = N0 * N1; + + const auto c_grid_desc_mraw_nraw = [&]() { + const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( + make_tuple(M0, M1, M2, N0, N1), + make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); + + return transform_tensor_descriptor( + c_grid_desc_m0_m1_m2_n0_n1, + make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), + make_merge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + const void* p_d_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + + if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + p_ds_grid_(I0) = static_cast(p_d_grid); + + const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_bias_c_permute< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_d, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + d_grid_desc, + e_grid_desc, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasCPermute_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index e572f4fa00..9824ad532a 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -11,8 +11,8 @@ namespace element_wise { struct Add { - template - __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; template <> __host__ __device__ constexpr void @@ -28,6 +28,13 @@ struct Add y = x0 + x1; }; + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + // Question: should half_t be supported ? template <> __host__ __device__ constexpr void