Add support for NGCHW in grouped conv bwd wei (#1491)

* Add support for NGCHW in grouped conv bwd wei

* Comments fixes

* navi fixes

* Update function names
This commit is contained in:
Bartłomiej Kocot
2024-09-03 10:52:03 +02:00
committed by GitHub
parent a9b170b541
commit 73b67f290f
24 changed files with 893 additions and 89 deletions

View File

@@ -41,6 +41,55 @@ __global__ void
elementwise_op);
}
template <typename GridwiseElementwiseFunctor,
typename InAGridDescTuple,
typename InBGridDescTuple,
typename OutAGridDescTuple,
typename OutBGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMapA,
typename Block2TileMapB,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
const InBGridDescTuple in_grid_desc_tuple_b,
const OutAGridDescTuple out_grid_desc_tuple_a,
const OutBGridDescTuple out_grid_desc_tuple_b,
const InDataTypePointerTuple p_in_global_tuple_a,
const InDataTypePointerTuple p_in_global_tuple_b,
const OutDataTypePointerTuple p_out_global_tuple_a,
const OutDataTypePointerTuple p_out_global_tuple_b,
const Block2TileMapA block_2_tile_map_a,
const Block2TileMapB block_2_tile_map_b,
const ElementwiseOperation elementwise_op,
const index_t a_grid_size)
{
if(get_block_1d_id() < a_grid_size)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
out_grid_desc_tuple_a,
p_in_global_tuple_a,
p_out_global_tuple_a,
block_2_tile_map_a,
elementwise_op,
get_block_1d_id());
}
else
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
out_grid_desc_tuple_b,
p_in_global_tuple_b,
p_out_global_tuple_b,
block_2_tile_map_b,
elementwise_op,
get_block_1d_id() - a_grid_size);
}
}
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
@@ -133,7 +182,8 @@ struct GridwiseElementwise
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
const ElementwiseOperation& elementwise_op,
const index_t block_id = get_block_1d_id())
{
constexpr auto src_datas = generate_tuple(
@@ -169,7 +219,7 @@ struct GridwiseElementwise
Number<NumOutput>{});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);