mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add support for NGCHW in grouped conv bwd wei (#1491)
* Add support for NGCHW in grouped conv bwd wei * Comments fixes * navi fixes * Update function names
This commit is contained in:
@@ -41,6 +41,55 @@ __global__ void
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InAGridDescTuple,
|
||||
typename InBGridDescTuple,
|
||||
typename OutAGridDescTuple,
|
||||
typename OutBGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename Block2TileMapA,
|
||||
typename Block2TileMapB,
|
||||
typename ElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
|
||||
const InBGridDescTuple in_grid_desc_tuple_b,
|
||||
const OutAGridDescTuple out_grid_desc_tuple_a,
|
||||
const OutBGridDescTuple out_grid_desc_tuple_b,
|
||||
const InDataTypePointerTuple p_in_global_tuple_a,
|
||||
const InDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_b,
|
||||
const Block2TileMapA block_2_tile_map_a,
|
||||
const Block2TileMapB block_2_tile_map_b,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t a_grid_size)
|
||||
{
|
||||
if(get_block_1d_id() < a_grid_size)
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
@@ -133,7 +182,8 @@ struct GridwiseElementwise
|
||||
const InDataTypePointerTuple& p_in_global_tuple,
|
||||
const OutDataTypePointerTuple& p_out_global_tuple,
|
||||
const Block2TileMap& block_2_tile_map,
|
||||
const ElementwiseOperation& elementwise_op)
|
||||
const ElementwiseOperation& elementwise_op,
|
||||
const index_t block_id = get_block_1d_id())
|
||||
{
|
||||
|
||||
constexpr auto src_datas = generate_tuple(
|
||||
@@ -169,7 +219,7 @@ struct GridwiseElementwise
|
||||
Number<NumOutput>{});
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
|
||||
|
||||
const index_t m0_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
|
||||
|
||||
Reference in New Issue
Block a user