mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Implement DPP8 based GEMM for Navi21 (#826)
This commit is contained in:
committed by
GitHub
parent
f60f0a5e03
commit
d4c84256f7
@@ -7,9 +7,11 @@
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
|
||||
@@ -17,6 +19,8 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
@@ -25,7 +29,8 @@ template <typename GridwiseGemm,
|
||||
typename CGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
bool HasDoubleTailKBlockLoop,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -38,6 +43,13 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
// DPP8 is currently only supported on gfx1030
|
||||
#if !defined(__gfx1030__)
|
||||
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
@@ -88,7 +100,8 @@ template <index_t BlockSize,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
|
||||
struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -244,6 +257,45 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
|
||||
__host__ __device__ static constexpr auto GetBlockwiseGemm()
|
||||
{
|
||||
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
ABlockDesc_BK0_BM_BK1,
|
||||
BBlockDesc_BK0_BN_BK1,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
ABlockDesc_BK0_BM_BK1,
|
||||
BBlockDesc_BK0_BN_BK1,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
|
||||
using CGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
@@ -274,7 +326,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
const auto c_m0_n0_block_cluster_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
// HACK: this forces index data into SGPR
|
||||
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
|
||||
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
|
||||
|
||||
@@ -372,20 +424,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
@@ -472,7 +511,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
|
||||
b_block_slice_copy_step);
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
@@ -992,7 +1031,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
|
||||
b_block_slice_copy_step);
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
// LDS double buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user