diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index f8b886ebbf..a5933262a5 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -6,6 +6,8 @@ if(DL_KERNELS) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_dependencies(example_gemm_dl example_gemm_dl_fp16) + add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16) endif() if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) diff --git a/example/01_gemm/gemm_dl_dpp8_fp16.cpp b/example/01_gemm/gemm_dl_dpp8_fp16.cpp new file mode 100644 index 0000000000..ea0ba39076 --- /dev/null +++ b/example/01_gemm/gemm_dl_dpp8_fp16.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 950121feb3..0e7e690421 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -125,6 +125,9 @@ // `s_nop`s to avoid hazard #define CK_USE_AMD_V_DOT_INLINE_ASM 0 +// inner product using V_DOT with DPP8 modifiers +#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp new file mode 100644 index 0000000000..e527509f57 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp" + +namespace ck { + +/** + * DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit + * the data loaded from LDS to registers. + * + * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C + * between them in such a way that threads from the same group need the same chunk of either + * matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the + * whole chunk from LDS to its own register space. + * Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size` + * of the chunk, and then share that data with other threads from the same lane group. + * + * Assumptions coming from the usage of DPP8: + * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or + * `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` - + * - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either + * matrix A or B; + * - based on these values we determine which matrix to share. + * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or + * `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) - + * - we have to make sure that the data to split is divisible by the number of + * threads in the group. + * + * General algorithm: + * C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] + * A and B are visible to the whole block, C is distributed among each thread + * Assume: + * 1. A: + * 1. ABlockDesc_BK0_BM_BK1 is known at compile-time + * 2. ABlockBuffer is DynamicBuffer + * 2. B: + * 1. BBlockDesc_BK0_BN_BK1 is known at compile-time + * 2. BBlockBuffer is DynamicBuffer + * 3. C: + * 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time + * 2. CThreadBuffer is StaticBuffer + * 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 + */ +template + typename BM10BN10ThreadClusterBN10Xs, // Sequence + index_t AThreadCopyScalarPerVector_BM11, + index_t BThreadCopyScalarPerVector_BN11, + typename enable_if::type = false> +struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0 +{ + using AIndex = MultiIndex<4>; + using BIndex = MultiIndex<4>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); + + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; + + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; + + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; + + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; + + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; + + // We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all + // threads in a lane group need the same chunk of B or A matrices and we can share them using + // DPP. + static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size); + static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false; + static constexpr bool ShareA = !ShareB; + + // If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`, + // respectively) elements, so we split them between threads in lane group so each thread loads + // less data from LDS. + static constexpr index_t BM1PerThread = + ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11; + static constexpr index_t BN1PerThread = + ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11; + + __host__ __device__ static constexpr auto + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) + { + const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_block_bk0_bm0_bm1_bk1; + } + + __host__ __device__ static constexpr auto + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) + { + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_block_desc_bk0_bn0_bn1_bk1; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; + } + + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() + { + return Sequence{}; + } + + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); + + public: + __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id())}, + a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()}, + b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()} + { + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); + + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + } + + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) + { + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); + + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1() + { + const auto offsetBM0 = c_thread_origin_data_idx_[I0]; + // If sharing matrix A, we need a separate BM1 offset for each thread in lane group. + const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] + + dpp8::get_thread_idx_in_lane_group() * BM1PerThread + : c_thread_origin_data_idx_[I1]; + return make_tuple(0, offsetBM0, offsetBM1, 0); + } + + __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1() + { + const auto offsetBN0 = c_thread_origin_data_idx_[I2]; + // If sharing matrix B, we need a separate BN1 offset for each thread in lane group. + const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] + + dpp8::get_thread_idx_in_lane_group() * BN1PerThread + : c_thread_origin_data_idx_[I3]; + return make_tuple(0, offsetBN0, offsetBN1, 0); + } + + template + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); + + constexpr auto threadwise_contraction = + ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>, + ShareA>{}; + + static_for<0, BN0, 1>{}([&](auto bn0) { + static_for<0, BM0, 1>{}([&](auto bm0) { + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, bm0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, bn0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(bm0, I0, bn0, I0)); + + static_for{}([&](auto bk0) { + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, bm0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, bn0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(bm0, I0, bn0, I0)); + }); + }); + }); + } + + private: + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{}, Number{})); + + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index b3caa3214a..f23404a1d7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -11,7 +11,7 @@ namespace ck { // C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] -// A and B are visable to the whole block, C is distributed among each thread +// A and B are visible to the whole block, C is distributed among each thread // Assume: // 1. A: // 1. ABlockDesc_BK0_BM_BK1 is known at compile-time diff --git a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp new file mode 100644 index 0000000000..16ca582b89 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct GemmDlAlgorithm +{ + Default, // Uses DOT vector instructions + Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index 13e9f96910..eedf384cd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -11,6 +11,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -59,6 +60,7 @@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, + GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default, enable_if_t< is_same_v && is_same_v && @@ -236,7 +238,8 @@ struct DeviceGemmDl : public DeviceGemm; + CThreadTransferDstScalarPerVector, + GemmDlAlg>; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -372,7 +375,8 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - true>; + true, + GemmDlAlg>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -398,7 +402,8 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - false>; + false, + GemmDlAlg>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -424,7 +429,8 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - true>; + true, + GemmDlAlg>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -450,7 +456,8 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - false>; + false, + GemmDlAlg>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -485,6 +492,16 @@ struct DeviceGemmDl : public DeviceGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" +#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDlDpp8 : public DeviceGemmDl + +{ + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmDlDpp8" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index d46aea5e22..1d1bb6ed2d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -7,9 +7,11 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" @@ -17,6 +19,8 @@ namespace ck { +using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm; + template + bool HasDoubleTailKBlockLoop, + GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -38,6 +43,13 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { +// DPP8 is currently only supported on gfx1030 +#if !defined(__gfx1030__) + if(GemmDlAlg == GemmDlAlgorithm::Dpp8) + { + return; + } +#endif constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -88,7 +100,8 @@ template + index_t CThreadTransferDstScalarPerVector, + GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default> struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 c_grid_desc_m_n); } + template + __host__ __device__ static constexpr auto GetBlockwiseGemm() + { + if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) + { + return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + ABlockDesc_BK0_BM_BK1, + BBlockDesc_BK0_BN_BK1, + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + } + else + { + return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + ABlockDesc_BK0_BM_BK1, + BBlockDesc_BK0_BN_BK1, + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + } + } + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = @@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - // HACK: this force index data into SGPR + // HACK: this forces index data into SGPR const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); @@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - decltype(a_k0_m_k1_block_desc), - decltype(b_k0_n_k1_block_desc), - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; + GetBlockwiseGemm(); constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); @@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); @@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp new file mode 100644 index 0000000000..d0d214381d --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/inner_product_dpp8.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +/** + * Threadwise contraction using dot instructions with DPP8 modifier. + * + * Assumptions: + * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1` + * are known at compile-time; + * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time; + * 3. `TM0` is equal to 1 and `TN0` is equal to 1; + * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by + * the size of the lane group (`dpp8::lane_group_size`). + */ +template ::type = false> +struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t TK0 = TKLengths{}[I0]; + static constexpr index_t TK1 = TKLengths{}[I1]; + static constexpr index_t TM0 = TMLengths{}[I0]; + static constexpr index_t TM1 = TMLengths{}[I1]; + static constexpr index_t TN0 = TNLengths{}[I0]; + static constexpr index_t TN1 = TNLengths{}[I1]; + + static_assert(TM0 == 1 && TN0 == 1); + + static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) || + (!ShareA && TN1 % dpp8::lane_group_size == 0)); + static constexpr index_t shared_elems_per_lane = + ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size; + + __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN1, 1>{}([&](auto tn1) { + vector_type a_vec; + vector_type b_vec; + + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1; + constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1)); + + constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane; + constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1)); + + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(0, tm1, 0, tn1)); + + constexpr int src_lane = + ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size + : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size; + + dpp8::inner_product_dpp( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); + }); + }); + }); + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp new file mode 100644 index 0000000000..8d6c7eede9 --- /dev/null +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_gemm_dpp.hpp" + +namespace ck { + +namespace dpp8 { + +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + +__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } +__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } + +} // namespace dpp8 + +} // namespace ck diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp new file mode 100644 index 0000000000..ccd7a4e628 --- /dev/null +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "amd_gemm_dpp.hpp" +#include "data_type.hpp" +#include "type_convert.hpp" + +namespace ck { + +namespace dpp8 { + +template +__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); + +// clang-format off +template <> +__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +// clang-format on + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly. + */ +template +__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c) +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + if constexpr(ShareA) + { + inline_v_dot2c_dpp8_instr(a, b, c); + } + else + { + inline_v_dot2c_dpp8_instr(b, a, c); + } +} + +/** + * DPP8 instrinsics expects to get an integer mask, hardcoding integers for specific broadcast + * patters. + */ +constexpr std::array IntrinsicMaskDpp8 = { + 0, // 0, 0, 0, 0, 0, 0, 0, 0 + 2396745, // 1, 1, 1, 1, 1, 1, 1, 1 + 4793490, // 2, 2, 2, 2, 2, 2, 2, 2 + 7190235, // 3, 3, 3, 3, 3, 3, 3, 3 + 9586980, // 4, 4, 4, 4, 4, 4, 4, 4 + 11983725, // 5, 5, 5, 5, 5, 5, 5, 5 + 14380470, // 6, 6, 6, 6, 6, 6, 6, 6 + 16777215, // 7, 7, 7, 7, 7, 7, 7, 7 +}; + +/** + * Returns DPP8 sel modifier as an integer required for the intrinsic instruction. + */ +template +constexpr int get_dpp_sel_mask_broadcast() +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + return IntrinsicMaskDpp8[SrcLaneIdx]; +} + +template +__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c) +{ + constexpr int sel_mask = get_dpp_sel_mask_broadcast(); + const half2_t val_from_other_lane = + bit_cast(__builtin_amdgcn_mov_dpp8(bit_cast(a), sel_mask)); + c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false); +} + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics. + */ +template +__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c) +{ + if constexpr(ShareA) + { + intrinsic_fdot2_impl(a, b, c); + } + else + { + intrinsic_fdot2_impl(b, a, c); + } +} + +/** + * Dot product of two input vectors `a`, `b` using `v_dot` instructions with DPP modifier. + * + * DPP modifier allows us to share one of the vectors between lanes in a lane group. + * When `ShareA` is set, instruction uses vector `a` from lane `SrcLaneIdx` from the same + * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared. + * Note that all the threads in a lane group uses the same vector - broadcast pattern. + * + * `SrcLaneIdx` must be in range from 0 to 7. + */ +template +__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c) +{ +#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM + inline_v_dot2c_dpp8(a, b, c); +#else + intrinsic_fdot2(a, b, c); +#endif +} + +} // namespace dpp8 + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 538ca758f7..06197d9f46 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -23,6 +23,11 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( std::vector>>& @@ -33,6 +38,11 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( std::vector>>& @@ -43,6 +53,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( std::vector>>& @@ -53,6 +68,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( std::vector>>& @@ -354,6 +374,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } @@ -364,6 +385,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); @@ -375,6 +397,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } @@ -385,6 +408,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 266a5d4b50..5162617b18 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -31,6 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp) endif() list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..53fc74a392 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..35d46b32f5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..22e2e5f398 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..49e9416c81 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck