[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)

Cleanup and refactoring related to tile loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for
fp16/bf16 x fp8

Key changes:

- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and
refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead
of return values for consistency
- Removed unused variable declarations and simplified type deduction
logic
- Changed load_tile_with_elementwise to take its tile windows as an
explicit tuple type for clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which helps the maintainers
understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [X] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
SamiAario-AMD
2026-03-02 12:21:44 +00:00
committed by assistant-librarian[bot]
parent 0438ab1b79
commit 95dc496d30
47 changed files with 190 additions and 182 deletions

View File

@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should works with unknow number of params
// Load element_wise API works only when the input typle is a tuple-tyupe
return tile_window[number<0>{}].load(
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
// TODO: Tile windows should work with unknown number of params
// Load element_wise API works only when the input type is a tuple-type
return tile_windows[number<0>{}].load(
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
@@ -131,7 +131,7 @@ template <typename T,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
CK_TILE_DEVICE void load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,

View File

@@ -374,6 +374,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -381,18 +382,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -402,21 +404,28 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
CK_TILE_DEVICE void load_tile_transpose_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
// Check that the tile distribution of out_tensor is the expected one for transposed loads.
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
static_assert(std::is_same_v<decltype(make_static_tile_distribution(OutTileDstrEncode{})),
remove_cvref_t<decltype(output_distr)>>);
// Check that the datatype of out_tensor matches that of the bottom tensor view.
static_assert(std::is_same_v<typename DistributedTensor_::DataType,
typename BottomTensorView_::DataType>);
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
@@ -443,8 +452,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
return out_tensor;
}
/**
@@ -456,6 +463,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -463,16 +471,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
    typename DistributedTensor_,
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename        = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE void
load_tile_transpose(DistributedTensor_& out_tensor,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& __restrict__ tile_window)
{
    // Zero-offset convenience wrapper: the transposed load into out_tensor is
    // done entirely by load_tile_transpose_with_offset.
    //
    // The template arguments are spelled out so that a caller-selected Policy
    // reaches the callee. With plain deduction the callee would fall back to its
    // own default Policy, and the policy validated by this overload's enable_if
    // would not be the one actually used for the load.
    load_tile_transpose_with_offset<DistributedTensor_,
                                    BottomTensorView_,
                                    WindowLengths_,
                                    TileDistribution_,
                                    NumCoord,
                                    Policy>(out_tensor, tile_window, 0);
}
template <
typename BottomTensorView_,
typename WindowLengths_,
@@ -489,7 +518,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
return load_tile_transpose_with_offset(tile_window, 0);
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
return out_tensor;
}
} // namespace ck_tile

View File

@@ -190,11 +190,11 @@ struct tile_window_with_static_distribution
* The same thread, during vectorized reading, accesses the same set of
* data from A0, A1, A2, … AN.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -202,7 +202,7 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor,
tile_window,
tile_windows,
elementwise,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
@@ -210,12 +210,12 @@ struct tile_window_with_static_distribution
}
template <typename DistributedTensor,
typename TileWindow_,
typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -226,14 +226,14 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto sizeOfTuple = TileWindow_::size();
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
@@ -244,7 +244,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const auto idx_vec_value = generate_tuple(
[&](auto jj) {
return tile_window[number<jj>{}]
return tile_windows[number<jj>{}]
.get_bottom_tensor_view()
.template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,