diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a671877f68..3dc2a0966e 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -6,8 +6,7 @@ if(DL_KERNELS) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_dependencies(example_gemm_dl example_gemm_dl_fp16) - add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16) + add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) endif() if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) diff --git a/example/01_gemm/gemm_dl_dpp8_fp16.cpp b/example/01_gemm/gemm_dl_dpp8_fp16.cpp deleted file mode 100644 index ea0ba39076..0000000000 --- a/example/01_gemm/gemm_dl_dpp8_fp16.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = Col; -using BLayout = Row; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 0000000000..7a9e3f6186 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp deleted file mode 100644 index e527509f57..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,370 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp" - -namespace ck { - -/** - * DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit - * the data loaded from LDS to registers. - * - * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C - * between them in such a way that threads from the same group need the same chunk of either - * matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the - * whole chunk from LDS to its own register space. - * Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size` - * of the chunk, and then share that data with other threads from the same lane group. - * - * Assumptions coming from the usage of DPP8: - * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or - * `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` - - * - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either - * matrix A or B; - * - based on these values we determine which matrix to share. - * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or - * `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) - - * - we have to make sure that the data to split is divisible by the number of - * threads in the group. - * - * General algorithm: - * C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] - * A and B are visible to the whole block, C is distributed among each thread - * Assume: - * 1. A: - * 1. ABlockDesc_BK0_BM_BK1 is known at compile-time - * 2. ABlockBuffer is DynamicBuffer - * 2. B: - * 1. BBlockDesc_BK0_BN_BK1 is known at compile-time - * 2. BBlockBuffer is DynamicBuffer - * 3. C: - * 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time - * 2. CThreadBuffer is StaticBuffer - * 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 - */ -template - typename BM10BN10ThreadClusterBN10Xs, // Sequence - index_t AThreadCopyScalarPerVector_BM11, - index_t BThreadCopyScalarPerVector_BN11, - typename enable_if::type = false> -struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0 -{ - using AIndex = MultiIndex<4>; - using BIndex = MultiIndex<4>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); - static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); - static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); - static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - - static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; - static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; - - static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; - static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; - - static constexpr index_t BM11 = BM1PerThreadBM11; - static constexpr index_t BN11 = BN1PerThreadBN11; - - static constexpr index_t BM1 = BM100 * BM101 * BM11; - static constexpr index_t BN1 = BN100 * BN101 * BN11; - - static constexpr index_t BM0 = BM / BM1; - static constexpr index_t BN0 = BN / BN1; - - // We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all - // threads in a lane group need the same chunk of B or A matrices and we can share them using - // DPP. - static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size); - static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false; - static constexpr bool ShareA = !ShareB; - - // If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`, - // respectively) elements, so we split them between threads in lane group so each thread loads - // less data from LDS. - static constexpr index_t BM1PerThread = - ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11; - static constexpr index_t BN1PerThread = - ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11; - - __host__ __device__ static constexpr auto - MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) - { - const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( - a_block_desc_bk0_bm_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_block_bk0_bm0_bm1_bk1; - } - - __host__ __device__ static constexpr auto - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) - { - const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( - b_block_desc_bk0_bn_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return b_block_desc_bk0_bn0_bn1_bk1; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM, BN] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM0, BM1, BN0, BN1] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; - } - - __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() - { - return Sequence{}; - } - - static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = - MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); - - static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); - - public: - __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0() - : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id())}, - a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()}, - b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()} - { - static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && - BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); - - static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == - BBlockDesc_BK0_BN_BK1{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && - BM10BN10ThreadClusterBN10Xs::Size() == 2, - "wrong!"); - } - - __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) - { - // lower: [BM0, BM1, BN0, BN1] - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - constexpr auto adaptor0 = - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); - - // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // upper: [Tid, BM0, BM11, BN0, BN11] - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), - make_pass_through_transform(BM0), - make_pass_through_transform(BM11), - make_pass_through_transform(BN0), - make_pass_through_transform(BN11)), - make_tuple( - Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - - return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); - } - - __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1() - { - const auto offsetBM0 = c_thread_origin_data_idx_[I0]; - // If sharing matrix A, we need a separate BM1 offset for each thread in lane group. - const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] + - dpp8::get_thread_idx_in_lane_group() * BM1PerThread - : c_thread_origin_data_idx_[I1]; - return make_tuple(0, offsetBM0, offsetBM1, 0); - } - - __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1() - { - const auto offsetBN0 = c_thread_origin_data_idx_[I2]; - // If sharing matrix B, we need a separate BN1 offset for each thread in lane group. - const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] + - dpp8::get_thread_idx_in_lane_group() * BN1PerThread - : c_thread_origin_data_idx_[I3]; - return make_tuple(0, offsetBN0, offsetBN1, 0); - } - - template - __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, - const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - auto a_thread_buf = make_static_buffer( - a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); - - constexpr auto threadwise_contraction = - ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< - FloatA, - FloatB, - FloatC, - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - CThreadDesc_BM0_BM11_BN0_BN11, - Sequence, - Sequence<1, BM1PerThreadBM11>, - Sequence<1, BN1PerThreadBN11>, - ShareA>{}; - - static_for<0, BN0, 1>{}([&](auto bn0) { - static_for<0, BM0, 1>{}([&](auto bm0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - - static_for{}([&](auto bk0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(bk0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(bk0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - }); - }); - }); - } - - private: - // A[BK0, BM0, BM1, BK1] - static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - // B[BK0, BN0, BN1, BK1] - static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatA, - FloatA, - decltype(a_block_desc_bk0_bm0_bm1_bk1_), - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatB, - FloatB, - decltype(b_block_desc_bk0_bn0_bn1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - CIndex c_thread_origin_data_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 0000000000..d62ed4b15d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d5a64d7aa6..1fee9c3225 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -4,27 +4,13 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) diff --git a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp deleted file mode 100644 index 16ca582b89..0000000000 --- a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck { -namespace tensor_operation { -namespace device { - -enum struct GemmDlAlgorithm -{ - Default, // Uses DOT vector instructions - Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eedf384cd9..d9f8324bd2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -11,7 +11,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,7 +59,6 @@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, - GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default, enable_if_t< is_same_v && is_same_v && @@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm; + CThreadTransferDstScalarPerVector>; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -375,8 +372,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -402,8 +398,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -429,8 +424,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -456,8 +450,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -492,16 +485,6 @@ struct DeviceGemmDl : public DeviceGemm -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - GemmSpecialization GemmSpec, - index_t BlockSize, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t K1, - index_t M1PerThread, - index_t N1PerThread, - index_t KPerThread, - typename M1N1ThreadClusterM1Xs, - typename M1N1ThreadClusterN1Xs, - typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - typename ABlockTransferSrcVectorTensorContiguousDimOrder, - typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - typename BBlockTransferSrcVectorTensorContiguousDimOrder, - typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - typename CThreadTransferSrcDstAccessOrder, - index_t CThreadTransferSrcDstVectorDim, - index_t CThreadTransferDstScalarPerVector, - enable_if_t< - is_same_v && - is_same_v && - is_same_v, - bool> = false> -struct DeviceGemmDlDpp8 : public DeviceGemmDl - -{ - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmDlDpp8" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock << ", " - << K1 << ", " - << M1PerThread << ", " - << N1PerThread << ", " - << KPerThread - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp new file mode 100644 index 0000000000..162f0690b8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1d1bb6ed2d..1da7236978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -7,11 +7,9 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" @@ -19,8 +17,6 @@ namespace ck { -using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm; - template + bool HasDoubleTailKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -43,13 +38,6 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -// DPP8 is currently only supported on gfx1030 -#if !defined(__gfx1030__) - if(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return; - } -#endif constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -100,8 +88,7 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 c_grid_desc_m_n); } - template - __host__ __device__ static constexpr auto GetBlockwiseGemm() - { - if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - else - { - return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - } - using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = @@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - GetBlockwiseGemm(); + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 0000000000..d39f0ba30b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636de..754a3e89c9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp deleted file mode 100644 index d0d214381d..0000000000 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/inner_product_dpp8.hpp" -#include "ck/utility/math.hpp" - -namespace ck { - -/** - * Threadwise contraction using dot instructions with DPP8 modifier. - * - * Assumptions: - * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1` - * are known at compile-time; - * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time; - * 3. `TM0` is equal to 1 and `TN0` is equal to 1; - * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by - * the size of the lane group (`dpp8::lane_group_size`). - */ -template ::type = false> -struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t TK0 = TKLengths{}[I0]; - static constexpr index_t TK1 = TKLengths{}[I1]; - static constexpr index_t TM0 = TMLengths{}[I0]; - static constexpr index_t TM1 = TMLengths{}[I1]; - static constexpr index_t TN0 = TNLengths{}[I0]; - static constexpr index_t TN1 = TNLengths{}[I1]; - - static_assert(TM0 == 1 && TN0 == 1); - - static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) || - (!ShareA && TN1 % dpp8::lane_group_size == 0)); - static constexpr index_t shared_elems_per_lane = - ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size; - - __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() - { - static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && - BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && - CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, - "wrong!"); - } - - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; - - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1; - constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1)); - - constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane; - constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1)); - - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(0, tm1, 0, tn1)); - - constexpr int src_lane = - ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size - : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size; - - dpp8::inner_product_dpp( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); - } -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 0000000000..03a4d17c9b --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_16x16x2 = 0, + dpp8_f16_8x32x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. + */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32, + "MPerDpp must be either 8, 16 or 32."); + static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32, + "NPerDpp must be either 8, 16 or 32."); + + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp index 8d6c7eede9..a28292dade 100644 --- a/include/ck/utility/amd_gemm_dpp.hpp +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -5,17 +5,63 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/math.hpp" -#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/inner_product_dpp8.hpp" namespace ck { namespace dpp8 { -/// Number of lanes that can share data using DPP8 modifiers. -constexpr index_t lane_group_size = 8; +template +struct dpp_datatypes; -__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } -__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; } // namespace dpp8 diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp index ccd7a4e628..f079e2ca64 100644 --- a/include/ck/utility/inner_product_dpp8.hpp +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once + #include "amd_gemm_dpp.hpp" #include "data_type.hpp" #include "type_convert.hpp" @@ -10,6 +11,9 @@ namespace ck { namespace dpp8 { +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + template __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp new file mode 100644 index 0000000000..b2eb0ddb93 --- /dev/null +++ b/include/ck/utility/loop_scheduler.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +} + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 9310fd433b..b4fa746485 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( std::vector>>& instances); @@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( std::vector>>& instances); @@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); @@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); @@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } @@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); @@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } @@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 5162617b18..48dd292f05 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -31,10 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp) endif() list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index 53fc74a392..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index 35d46b32f5..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 22e2e5f398..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 49e9416c81..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..80da6d4c3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 4, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 4, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 4, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 4, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 4, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..97b85fd1a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 8, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 8, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 8, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 8, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..370ebfccc4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 4, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 4, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 4, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 4, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..6053275492 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 8, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 8, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 8, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 8, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck