mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)
Cleanup and refactoring related to tile loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Cleanup and refactoring done while implementing mixed precision for fp16/bf16 x fp8 Key changes: - Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and refactored the API to use consistent naming conventions - Updated load_tile_transpose functions to use output parameters instead of return values for consistency - Removed unused variable declarations and simplified type deduction logic - Define load_tile_with_elementwise to use tuple types explicitly for clarity ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [X] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0438ab1b79
commit
95dc496d30
@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should works with unknow number of params
|
||||
// Load element_wise API works only when the input typle is a tuple-tyupe
|
||||
return tile_window[number<0>{}].load(
|
||||
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
// TODO: Tile windows should work with unknown number of params
|
||||
// Load element_wise API works only when the input type is a tuple-type
|
||||
return tile_windows[number<0>{}].load(
|
||||
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -131,7 +131,7 @@ template <typename T,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
CK_TILE_DEVICE void load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
|
||||
@@ -374,6 +374,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -381,18 +382,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* @param offset The offset (in elements) added to the base address before
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
@@ -402,21 +404,28 @@ template <
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
CK_TILE_DEVICE void load_tile_transpose_with_offset(
|
||||
DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
|
||||
|
||||
// Check that the tile distribution of out_tensor is the expected one for transposed loads.
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
static_assert(std::is_same_v<decltype(make_static_tile_distribution(OutTileDstrEncode{})),
|
||||
remove_cvref_t<decltype(output_distr)>>);
|
||||
|
||||
// Check that the datatype of out_tensor matches that of the bottom tensor view.
|
||||
static_assert(std::is_same_v<typename DistributedTensor_::DataType,
|
||||
typename BottomTensorView_::DataType>);
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
@@ -443,8 +452,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -456,6 +463,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -463,16 +471,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE void
|
||||
load_tile_transpose(DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
}
|
||||
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
@@ -489,7 +518,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
return load_tile_transpose_with_offset(tile_window, 0);
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -190,11 +190,11 @@ struct tile_window_with_static_distribution
|
||||
* The same thread, during vectorized reading, accesses the same set of
|
||||
* data from A0, A1, A2, … AN.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
|
||||
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -202,7 +202,7 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_window,
|
||||
tile_windows,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
@@ -210,12 +210,12 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename TileWindow_,
|
||||
typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
|
||||
const TileWindow_& tile_window,
|
||||
const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -226,14 +226,14 @@ struct tile_window_with_static_distribution
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
constexpr auto sizeOfTuple = TileWindow_::size();
|
||||
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
@@ -244,7 +244,7 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const auto idx_vec_value = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return tile_window[number<jj>{}]
|
||||
return tile_windows[number<jj>{}]
|
||||
.get_bottom_tensor_view()
|
||||
.template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
|
||||
Reference in New Issue
Block a user