mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporary disable splitK for gfx12
This commit is contained in:
@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// gfx11 doesn't support float atomic
|
||||
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
|
||||
// Todo: Enable splitK for gfx12
|
||||
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,14 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
/**
|
||||
* @brief Enable custom tensor transform for convolution backward data output.
|
||||
*
|
||||
* When set to 1, this macro enables a custom transformation of the output tensor
|
||||
* in convolution backward data operations.
|
||||
*/
|
||||
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
|
||||
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
|
||||
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
|
||||
// A: output tensor
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
@@ -762,12 +776,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
@@ -775,8 +783,46 @@ struct TransformConvBwdDataToGemm_v1
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
#else
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Ho_, I0, I0),
|
||||
make_pad_transform(Wo_, I0, I0),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_conv_bwd_data_out_transform(N_,
|
||||
Ho_,
|
||||
Wo_,
|
||||
K_,
|
||||
YDot_,
|
||||
XDot_,
|
||||
HTilde_,
|
||||
WTilde_,
|
||||
ConvDilationH_,
|
||||
ConvDilationW_,
|
||||
HTildeSlice,
|
||||
WTildeSlice,
|
||||
YDotSlice,
|
||||
XDotSlice,
|
||||
IHTildeSliceBegin,
|
||||
IWTildeSliceBegin,
|
||||
GcdStrideDilationH_,
|
||||
GcdStrideDilationW_,
|
||||
AK0,
|
||||
AK1,
|
||||
GemmMPerBlock,
|
||||
GemmKPerBlock)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
return out_n_hop_wop_k_grid_desc_final;
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user