Fix K padding calculation for grouped conv data (#876)

* Fix K padding calculation for grouped conv data

* Restore previous padd for 1x1 specialization

[ROCm/composable_kernel commit: c981f6d033]
This commit is contained in:
Bartłomiej Kocot
2023-09-05 17:07:41 +02:00
committed by GitHub
parent 47958ebd07
commit fcafba0fd4
3 changed files with 21 additions and 15 deletions

View File

@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
BK1,
MPerBlock,
NPerBlock,
KPerBlock,
DoPadGemmM,
DoPadGemmN>{};

View File

@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);

View File

@@ -164,6 +164,7 @@ template <
index_t BK1,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
bool DoPadGemmN>
struct TransformConvBwdDataToGemm_v1
@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor
if constexpr(NDimSpatial == 2)
{
@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmk_gemm_padded_grid_desc =
const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
make_tuple(GemmKPerBlock, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc,
make_tuple(
make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));