mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add grouped conv bwd weight multi d kernel (#1237)
* Add grouped conv bwd weight multi d kernel
* Reference fix
* Fix cmake files
* bwd weight scale only xdl
* Fixes
* Fix client conv fwd example
This commit is contained in:
@@ -41,6 +41,58 @@ __global__ void
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename GridwiseElementwiseFunctor,
|
||||
typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
typename OutDataTypePointerTuple,
|
||||
typename Block2TileMap,
|
||||
typename ElementwiseOperation,
|
||||
index_t NumInputs,
|
||||
index_t NumOutputs>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
|
||||
const OutGridDescTuple out_grid_desc_tuple,
|
||||
const InDataTypePointerTuple p_in_global_tuple,
|
||||
const OutDataTypePointerTuple p_out_global_tuple,
|
||||
const Block2TileMap block_2_tile_map,
|
||||
const ElementwiseOperation elementwise_op,
|
||||
const index_t batch_count,
|
||||
const std::array<index_t, NumInputs> input_batch_strides,
|
||||
const std::array<index_t, NumOutputs> output_batch_strides)
|
||||
{
|
||||
static_assert(InGridDescTuple::Size() == NumInputs &&
|
||||
InDataTypePointerTuple::Size() == NumInputs);
|
||||
static_assert(OutGridDescTuple::Size() == NumOutputs &&
|
||||
OutDataTypePointerTuple::Size() == NumOutputs);
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
InDataTypePointerTuple p_in_global_with_offset_tuple;
|
||||
OutDataTypePointerTuple p_out_global_with_offset_tuple;
|
||||
|
||||
static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
|
||||
});
|
||||
|
||||
static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_out_global_with_offset_tuple(i) =
|
||||
p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
|
||||
});
|
||||
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
|
||||
out_grid_desc_tuple,
|
||||
p_in_global_with_offset_tuple,
|
||||
p_out_global_with_offset_tuple,
|
||||
block_2_tile_map,
|
||||
elementwise_op);
|
||||
}
|
||||
|
||||
template <typename InGridDescTuple,
|
||||
typename OutGridDescTuple,
|
||||
typename InDataTypePointerTuple,
|
||||
|
||||
Reference in New Issue
Block a user