[rocm-libraries] ROCm/rocm-libraries#4873 (commit 580ad4f)

[CK] CK Tile improvements and fixes for depthwise merged
 convolutions forward (#4873)

## Motivation
Performance benchmarks showed that old CK's depthwise merged
convolutions are much faster than CK Tile's ones.

## Technical Details
After investigation it turned out that the requirement that the A/C vector-load
size be a multiple of the gemm's rightmost dimension is too strict when
processing multiple groups: if the tensor is in NHWGC/NHWGK format and C/K is
equal to 1, we can instead use vector loads along the G dimension, which this
PR adds. A Filter5x5 specialization was also added, because some models use it;
it is similar to the 3x3 specialization, the only difference being the window
size. This addition was needed because of differences in tensor-descriptor
transformations between CK and CK Tile: in old CK, grouped depthwise 5x5 convs
were supported via the Default specialization, but in CK Tile that case did not
work properly.

## Test Plan
Performance was tested by our internal test suite, which contains
several DL models.

## Test Result
Test results showed a significant performance uplift for depthwise (3x3,
5x5) cases.
This commit is contained in:
jakpiase
2026-03-01 13:27:18 +00:00
committed by assistant-librarian[bot]
parent 1dd47118e2
commit d32d515f64
2 changed files with 95 additions and 72 deletions

View File

@@ -660,8 +660,8 @@ struct TransformConvFwdToGemm
else
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
make_tuple(N_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
@@ -669,26 +669,24 @@ struct TransformConvFwdToGemm
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(X_, C_))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_merge_transform(make_tuple(X_))),
make_tuple(sequence<0, 2, 3>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
@@ -906,11 +904,10 @@ struct TransformConvFwdToGemm
}
else
{
// IsSupported ensures C == 1 to allow reading on G dimension
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
@@ -919,12 +916,9 @@ struct TransformConvFwdToGemm
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
@@ -933,21 +927,15 @@ struct TransformConvFwdToGemm
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5>{},
sequence<6>{}));
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Y_, X_, C_))),
make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3, 6>{}),
make_merge_transform(make_tuple(Y_, X_))),
make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}
@@ -1214,14 +1202,10 @@ struct TransformConvFwdToGemm
}
else
{
// IsSupported ensures C == 1 to allow reading on G dimension
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_),
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
number<VectorSizeA>{},
I1);
@@ -1231,20 +1215,11 @@ struct TransformConvFwdToGemm
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}));
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
@@ -1255,27 +1230,21 @@ struct TransformConvFwdToGemm
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
make_tuple(sequence<0>{},
sequence<1, 2>{},
sequence<3, 4>{},
sequence<5, 6>{},
sequence<7>{},
sequence<8>{}));
sequence<7>{}));
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5, 8>{}),
make_merge_transform(make_tuple(Z_, Y_, X_))),
make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}