Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-03 21:21:22 +00:00.
Commit: Implement DPP8 based GEMM for Navi21 (#826)
This commit is contained in:
committed by
GitHub
parent
f60f0a5e03
commit
d4c84256f7
18
include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp
Normal file
18
include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Selects the inner-product strategy used by the DL (dot-product) GEMM kernels.
// Passed as a template argument down to the gridwise GEMM so the choice is
// resolved at compile time.
enum class GemmDlAlgorithm
{
    Default, // Uses DOT vector instructions
    Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -59,6 +60,7 @@ template <
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
@@ -236,7 +238,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlg>;
|
||||
|
||||
using AGridDesc_K0_M0_M1_K1 =
|
||||
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
@@ -372,7 +375,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
true,
|
||||
true>;
|
||||
true,
|
||||
GemmDlAlg>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -398,7 +402,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
true,
|
||||
false>;
|
||||
false,
|
||||
GemmDlAlg>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -424,7 +429,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
false,
|
||||
true>;
|
||||
true,
|
||||
GemmDlAlg>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -450,7 +456,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
false,
|
||||
false>;
|
||||
false,
|
||||
GemmDlAlg>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -485,6 +492,16 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1030")
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
@@ -492,10 +509,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -572,7 +586,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
virtual std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t K1,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs,
|
||||
typename M1N1ThreadClusterN1Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm::Dpp8>
|
||||
|
||||
{
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmDlDpp8"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< M1PerThread << ", "
|
||||
<< N1PerThread << ", "
|
||||
<< KPerThread
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user