mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Grouped conv backward data GKCYX support (#2029)
* Grouped conv backward data GKCYX support
* profiler
* Converter
* split instances
This commit is contained in:
@@ -93,6 +93,119 @@ __global__ void
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctorA,
|
||||
typename GridwiseElementwiseFunctorB,
|
||||
typename InAGridDescTuple,
|
||||
typename InBGridDescTuple,
|
||||
typename OutAGridDescTuple,
|
||||
typename OutBGridDescTuple,
|
||||
typename InADataTypePointerTuple,
|
||||
typename InBDataTypePointerTuple,
|
||||
typename OutADataTypePointerTuple,
|
||||
typename OutBDataTypePointerTuple,
|
||||
typename Block2TileMapA,
|
||||
typename Block2TileMapB,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumInputsA,
|
||||
index_t NumInputsB,
|
||||
index_t NumOutputsA,
|
||||
index_t NumOutputsB>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_elementwise_batched_dual(
|
||||
const InAGridDescTuple in_grid_desc_tuple_a,
|
||||
const InBGridDescTuple in_grid_desc_tuple_b,
|
||||
const OutAGridDescTuple out_grid_desc_tuple_a,
|
||||
const OutBGridDescTuple out_grid_desc_tuple_b,
|
||||
const InADataTypePointerTuple p_in_global_tuple_a,
|
||||
const InBDataTypePointerTuple p_in_global_tuple_b,
|
||||
const OutADataTypePointerTuple p_out_global_tuple_a,
|
||||
const OutBDataTypePointerTuple p_out_global_tuple_b,
|
||||
const Block2TileMapA block_2_tile_map_a,
|
||||
const Block2TileMapB block_2_tile_map_b,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t a_grid_size,
|
||||
const index_t batch_count_a,
|
||||
const index_t batch_count_b,
|
||||
const std::array<index_t, NumInputsA> input_batch_strides_a,
|
||||
const std::array<index_t, NumInputsB> input_batch_strides_b,
|
||||
const std::array<index_t, NumOutputsA> output_batch_strides_a,
|
||||
const std::array<index_t, NumOutputsB> output_batch_strides_b)
|
||||
{
|
||||
static_assert(InAGridDescTuple::Size() == NumInputsA &&
|
||||
InADataTypePointerTuple::Size() == NumInputsA);
|
||||
static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
|
||||
OutADataTypePointerTuple::Size() == NumOutputsA);
|
||||
static_assert(InBGridDescTuple::Size() == NumInputsB &&
|
||||
InBDataTypePointerTuple::Size() == NumInputsB);
|
||||
static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
|
||||
OutBDataTypePointerTuple::Size() == NumOutputsB);
|
||||
|
||||
const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());
|
||||
|
||||
if(block_id < a_grid_size)
|
||||
{
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);
|
||||
|
||||
InADataTypePointerTuple p_in_global_with_offset_tuple;
|
||||
OutADataTypePointerTuple p_out_global_with_offset_tuple;
|
||||
|
||||
static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_in_global_with_offset_tuple(i) =
|
||||
p_in_global_tuple_a.At(i) +
|
||||
type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
|
||||
});
|
||||
|
||||
static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_out_global_with_offset_tuple(i) =
|
||||
p_out_global_tuple_a.At(i) +
|
||||
type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
|
||||
});
|
||||
|
||||
GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
|
||||
out_grid_desc_tuple_a,
|
||||
p_in_global_with_offset_tuple,
|
||||
p_out_global_with_offset_tuple,
|
||||
block_2_tile_map_a,
|
||||
elementwise_op,
|
||||
block_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
|
||||
const index_t g_idx =
|
||||
__builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);
|
||||
|
||||
InBDataTypePointerTuple p_in_global_with_offset_tuple;
|
||||
OutBDataTypePointerTuple p_out_global_with_offset_tuple;
|
||||
|
||||
static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_in_global_with_offset_tuple(i) =
|
||||
p_in_global_tuple_b.At(i) +
|
||||
type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
|
||||
});
|
||||
|
||||
static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_out_global_with_offset_tuple(i) =
|
||||
p_out_global_tuple_b.At(i) +
|
||||
type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
|
||||
});
|
||||
|
||||
GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
|
||||
out_grid_desc_tuple_b,
|
||||
p_in_global_with_offset_tuple,
|
||||
p_out_global_with_offset_tuple,
|
||||
block_2_tile_map_b,
|
||||
elementwise_op,
|
||||
block_id - a_grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
|
||||
Reference in New Issue
Block a user