Mirror of https://github.com/ROCm/composable_kernel.git
(synced 2026-05-14 18:17:44 +00:00)
Fix K padding calculation for grouped conv data (#876)
* Fix K padding calculation for grouped conv data
* Restore previous padding for 1x1 specialization
[ROCm/composable_kernel commit: c981f6d033]
This commit is contained in:
@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN>{};
|
||||
|
||||
|
||||
@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
|
||||
"KPerBlock must be divisible by AK1Value and BK1Value!");
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
|
||||
@@ -164,6 +164,7 @@ template <
|
||||
index_t BK1,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
bool DoPadGemmM,
|
||||
bool DoPadGemmN>
|
||||
struct TransformConvBwdDataToGemm_v1
|
||||
@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
// A: output tensor
|
||||
@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto out_gemmk_gemmm_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(AK1, GemmMPerBlock),
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto out_gemmk_gemmm_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(AK1, GemmMPerBlock),
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
|
||||
|
||||
// B weight tensor
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const auto wei_gemmk_gemmn_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmk_gemmnraw_grid_desc,
|
||||
make_tuple(BK1, GemmNPerBlock),
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto wei_gemmk_gemm_padded_grid_desc =
|
||||
const auto wei_gemmk_gemmn_padded_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmk_gemmnraw_grid_desc,
|
||||
make_tuple(BK1, GemmNPerBlock),
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
|
||||
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemm_padded_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user