mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Conv bwd data multiple d (#404)
* init commit of convnd bwd data * begin compiling example * have a first version that produce a right result * refine device level launch kernel code * add more instances in example and get right results * clang-format * format example file * add more instances * fix instances * adding conv_bwd_data multile_d * adding conv_bwd_data multile_d * adding conv_bwd multiple d * adding conv_bwd multiple d * adding conv_bwd multiple d * refactor * refactor * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * adding conv bwd data multiple d * refactor * update conv fwd's bias impl * refactor * reorg file * clean up cmake * clean * clean * clean Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -35,10 +35,6 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDesc_M_N,
|
||||
typename EGridDesc_M_N,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -166,6 +162,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// A desc for source in blockwise copy
|
||||
template <typename AGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
{
|
||||
@@ -182,6 +179,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
@@ -198,9 +196,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// E desc for destination in blockwise copy
|
||||
template <typename EGridDescriptor_M_N>
|
||||
__host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const EGridDescriptor_M_N& e_grid_desc_m_n)
|
||||
template <typename EGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
@@ -219,10 +217,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// Ds desc for source in blockwise copy
|
||||
template <typename DsGridDescriptor_M_N>
|
||||
template <typename DsGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -232,6 +229,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// return block_id to E matrix tile idx (m0, n0) mapping
|
||||
template <typename EGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
@@ -240,7 +238,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2ETileMap>
|
||||
template <typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDesc_M_N,
|
||||
typename EGridDesc_M_N,
|
||||
typename Block2ETileMap>
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
|
||||
const BGridDesc_N_K& b_grid_desc_n_k,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
@@ -314,23 +316,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
using DefaultAGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using DefaultBGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
@@ -342,9 +334,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user