mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4873 (commit 580ad4f)
[CK] CK Tile improvements and fixes for depthwise merged convolutions forward (#4873) ## Motivation Performance benchmarks showed that old CK's depthwise merged convolutions are much faster than CK Tile's. ## Technical Details Investigation revealed that the requirement that the A/C vector load size be a multiple of the gemm's rightmost dimension is too strict when processing multiple groups: if the tensor is in NHWGC/NHWGK format and C/K is equal to 1, vector loads can instead be performed along the G dimension, which this PR adds. A Filter5x5 specialization was also added because some models use it; it is similar to 3x3, the only difference being the window size. This addition was needed because of differences in tensor descriptor transformations between CK and CK Tile: in old CK the grouped depthwise 5x5 conv case was supported via the Default specialization, but in CK Tile that case did not work properly. ## Test Plan Performance was tested by our internal test suite, which contains several DL models. ## Test Result Results showed a significant performance uplift for the depthwise (3x3, 5x5) cases
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1dd47118e2
commit
d32d515f64
@@ -509,6 +509,7 @@ struct GroupedConvolutionForwardKernel
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I5 = number<5>();
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
@@ -744,8 +745,9 @@ struct GroupedConvolutionForwardKernel
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
// Check access for A tensor
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeA != 0 &&
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -753,6 +755,28 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
|
||||
"vector reads on group dimension!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -794,12 +818,30 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
// Try to read over G
|
||||
if(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Conv K is not a multiple of vector store size for output image!");
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0 ||
|
||||
ConvG % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge to allow "
|
||||
"writing over G dimension");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"ConvK is not a multiple of vector store size for output image!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -813,6 +855,18 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
|
||||
{
|
||||
// currently group merging works only for C == 1 due to tensor transformation
|
||||
// limitations
|
||||
if(ConvC != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("ConvC must be equal to 1 for NumGroupsToMerge > 1 to allow "
|
||||
"vector reads on group dimension!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
|
||||
@@ -660,8 +660,8 @@ struct TransformConvFwdToGemm
|
||||
else
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
@@ -669,26 +669,24 @@ struct TransformConvFwdToGemm
|
||||
in_n_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_n_x_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(X_, C_))),
|
||||
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
|
||||
make_merge_transform(make_tuple(X_))),
|
||||
make_tuple(sequence<0, 2, 3>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -906,11 +904,10 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
// IsSupported ensures C == 1 to allow reading on G dimension
|
||||
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(
|
||||
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
@@ -919,12 +916,9 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_groups_c_desc,
|
||||
@@ -933,21 +927,15 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}));
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_groups_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(Y_, X_, C_))),
|
||||
make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}),
|
||||
make_merge_transform(make_tuple(Y_, X_))),
|
||||
make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
@@ -1214,14 +1202,10 @@ struct TransformConvFwdToGemm
|
||||
}
|
||||
else
|
||||
{
|
||||
// IsSupported ensures C == 1 to allow reading on G dimension
|
||||
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStrideTensorA_,
|
||||
DiStride_,
|
||||
HiStride_,
|
||||
WiStride_,
|
||||
GStrideTensorA_,
|
||||
CStrideTensorA_),
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
|
||||
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
|
||||
@@ -1231,20 +1215,11 @@ struct TransformConvFwdToGemm
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_desc,
|
||||
@@ -1255,27 +1230,21 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_pass_through_transform(NumGroupsToMerge)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{},
|
||||
sequence<8>{}));
|
||||
sequence<7>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_c_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
|
||||
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}),
|
||||
make_merge_transform(make_tuple(Z_, Y_, X_))),
|
||||
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user