mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Add support for GKCYX grouped conv fwd (#2015)
* Add support for GKCYX grouped conv fwd * fixes * fix * changelog * Fixes
This commit is contained in:
@@ -41,13 +41,16 @@ __global__ void
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
template <typename GridwiseElementwiseFunctorA,
|
||||
typename GridwiseElementwiseFunctorB,
|
||||
typename InAGridDescTuple,
|
||||
typename InBGridDescTuple,
|
||||
typename OutAGridDescTuple,
|
||||
typename OutBGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename InADataTypePointerTuple,
|
||||
typename InBDataTypePointerTuple,
|
||||
typename OutADataTypePointerTuple,
|
||||
typename OutBDataTypePointerTuple,
|
||||
typename Block2TileMapA,
|
||||
typename Block2TileMapB,
|
||||
typename ElementwiseOperation>
|
||||
@@ -55,14 +58,14 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
|
||||
kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
|
||||
const InBGridDescTuple in_grid_desc_tuple_b,
|
||||
const OutAGridDescTuple out_grid_desc_tuple_a,
|
||||
const OutBGridDescTuple out_grid_desc_tuple_b,
|
||||
const InDataTypePointerTuple p_in_global_tuple_a,
|
||||
const InDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutDataTypePointerTuple p_out_global_tuple_b,
|
||||
const InADataTypePointerTuple p_in_global_tuple_a,
|
||||
const InBDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutADataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutBDataTypePointerTuple p_out_global_tuple_b,
|
||||
const Block2TileMapA block_2_tile_map_a,
|
||||
const Block2TileMapB block_2_tile_map_b,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
@@ -70,23 +73,23 @@ __global__ void
|
||||
{
|
||||
if(get_block_1d_id() < a_grid_size)
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_tuple_a,
|
||||
p_out_global_tuple_a,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
get_block_1d_id());
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_tuple_b,
|
||||
p_out_global_tuple_b,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
get_block_1d_id() - a_grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user