mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK_TILE] FMHA bwd Support hdim as a Multiple of 32 (#2130)
* Fix shuffle_tile
* Add fmha bwd d160
* CHANGELOG
* Use static_cast
* Update
---------
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
[ROCm/composable_kernel commit: 1926cd0cb8]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -533,6 +533,26 @@ struct tile_distribution_encoding
|
||||
}
|
||||
};
|
||||
|
||||
template <typename encoding, typename shuffle>
|
||||
class tile_distribution_encoding_shuffle;
|
||||
template <typename encoding, index_t... shuffle>
|
||||
class tile_distribution_encoding_shuffle<encoding, sequence<shuffle...>>
|
||||
{
|
||||
template <typename Ys2RHs>
|
||||
using shuffled = sequence<(Ys2RHs::template get<shuffle>())...>;
|
||||
|
||||
public:
|
||||
using type = tile_distribution_encoding<typename encoding::RsLengths,
|
||||
typename encoding::HsLengthss,
|
||||
typename encoding::Ps2RHssMajor,
|
||||
typename encoding::Ps2RHssMinor,
|
||||
shuffled<typename encoding::Ys2RHsMajor>,
|
||||
shuffled<typename encoding::Ys2RHsMinor>>;
|
||||
};
|
||||
template <typename encoding, typename shuffle>
|
||||
using tile_distribution_encoding_shuffle_t =
|
||||
typename tile_distribution_encoding_shuffle<encoding, shuffle>::type;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename OuterDstr, typename InnerDstr>
|
||||
|
||||
Reference in New Issue
Block a user