mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Support multi AB for grouped conv fwd xdl (#1027)
* Support multi AB for grouped conv fwd xdl * Add instances * Add client example * Add example * Add interface test * Minor fixes Minor fixes Minor fixes * Comment fixes * Fixes * Reference fix * Test xdl fixes * Improve multi_ab interface test
This commit is contained in:
@@ -203,7 +203,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
// A desc for source in blockwise copy
|
||||
template <typename AGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
|
||||
{
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
@@ -219,17 +219,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
|
||||
template <typename AsGridDesc_M_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
|
||||
MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
|
||||
[&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
@@ -245,10 +245,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
|
||||
template <typename BsGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
|
||||
MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
|
||||
[&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
|
||||
@@ -288,7 +288,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
// return block_id to E matrix tile idx (m0, n0) mapping
|
||||
template <typename EGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
|
||||
e_grid_desc_m_n);
|
||||
@@ -591,6 +591,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
|
||||
Number<NumATensor>{});
|
||||
|
||||
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
@@ -619,6 +622,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
|
||||
Number<NumBTensor>{});
|
||||
|
||||
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
@@ -1005,9 +1011,9 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
|
||||
|
||||
Reference in New Issue
Block a user