mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5334 (commit bb5a3c8)
[CK][CK Tile] Improve access for merged groups and remove modulo from xor (#5334) ## Motivation [CK][CK Tile] Improve access for merged groups and remove modulo from xor ## Technical Details - add template parameter to xor if modulo is needed. We don't need modulo for merged groups - use access by m for merged groups for a tensor - ## Test Plan test_grouped_convnd_fwd_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fd8714aea9
commit
db40d3f517
@@ -108,7 +108,9 @@ struct GroupedConvTraits
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using AsLayoutFwd = std::conditional_t<NumGroupsToMerge == 1,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
|
||||
@@ -518,10 +518,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
@@ -652,10 +654,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
@@ -788,10 +792,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
|
||||
@@ -1363,9 +1363,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
@@ -1429,9 +1431,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
@@ -1496,9 +1500,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
|
||||
Reference in New Issue
Block a user