Conv bwd data multiple d (#404)

* init commit of convnd bwd data

* begin compiling example

* have a first version that produce a right result

* refine device level launch kernel code

* add more instances in example and get right results

* clang-format

* format example file

* add more instances

* fix instances

* adding conv_bwd_data multile_d

* adding conv_bwd_data multile_d

* adding conv_bwd multiple d

* adding conv_bwd multiple d

* adding conv_bwd multiple d

* refactor

* refactor

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* adding conv bwd data multiple d

* refactor

* update conv fwd's bias impl

* refactor

* reorg file

* clean up cmake

* clean

* clean

* clean

Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Shaojie WANG
2022-09-20 00:25:28 +08:00
committed by GitHub
parent 43c898f6ff
commit 27858374ac
28 changed files with 2262 additions and 1161 deletions

View File

@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// E desc for destination in blockwise copy
template <typename EGridDescriptor_M_N>
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const EGridDescriptor_M_N& e_grid_desc_m_n)
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DefaultAGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using DefaultBGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using DsGridPointer = decltype(MakeDsGridPointer());
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{