mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
[rocm-libraries] ROCm/rocm-libraries#5334 (commit bb5a3c8)
[CK][CK Tile] Improve access for merged groups and remove modulo from xor (#5334) ## Motivation [CK][CK Tile] Improve access for merged groups and remove modulo from xor ## Technical Details - add a template parameter to the xor transform that controls whether modulo is applied; modulo is not needed for merged groups - use access by M for merged groups for the A tensor ## Test Plan test_grouped_convnd_fwd_tile ## Test Result passed locally ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fd8714aea9
commit
db40d3f517
@@ -513,7 +513,9 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor> ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge > 1,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
@@ -885,20 +887,51 @@ struct GroupedConvolutionForwardKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
|
||||
{
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge == 1)
|
||||
{
|
||||
// Access by K
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view =
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{block_idx_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Access by M
|
||||
const auto a_desc_reversed = transform_tensor_descriptor(
|
||||
a_desc,
|
||||
make_tuple(make_pass_through_transform(a_desc.get_length(I0)),
|
||||
make_pass_through_transform(a_desc.get_length(I1))),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Step 1: Create tensor view
|
||||
const auto& a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, a_desc_reversed);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& a_pad_view =
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, block_idx_m});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BDescType>
|
||||
|
||||
@@ -108,7 +108,9 @@ struct GroupedConvTraits
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using AsLayoutFwd = std::conditional_t<NumGroupsToMerge == 1,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
|
||||
@@ -518,10 +518,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
@@ -652,10 +654,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
@@ -788,10 +792,12 @@ struct TransformConvBwdWeightToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_),
|
||||
make_pass_through_transform(Z_ * Y_ * X_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
|
||||
// Merge To M, N
|
||||
|
||||
@@ -1363,9 +1363,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
@@ -1429,9 +1431,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
@@ -1496,9 +1500,11 @@ struct TransformConvFwdToGemm
|
||||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
|
||||
// Merge To M, N
|
||||
|
||||
Reference in New Issue
Block a user