[CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905)

* Allow sharing partition index across threads

* Fix typo PartitoinIndex -> PartitionIndex

* Remove C++20 'requires' usages

* Add missing template arguments

* Fix load_tile() overload ambiguity issue

* Use SFINAE to exclude invalid arguments

* Add additional offset parameter to the async_load_tile()

* Remove async_load_tile() default argument to avoid ambiguity

* Extract tile_window coordinate compute logic as method

* Use warp-shared LDS base address in tile_window::async_load()

* Add constraint to tile_window::load() templates

* Fix wrong type traits is_class_v<> usages

* Add missing constraint to async_load_tile()

* Add missing tile_window::load() overload

* Add more constraint to avoid load_tile() call ambiguity

* Rename ParitionIndex as ReplacementPartitionIndex

* Update pre_computed_warp_coords_ in move_extended()

* Fix inconsistency between template parameters and documentation

* Allow specifying pre-computed partition index

* Add type traits is_sequence<> & is_tile_distribution<>

* Add type traits is_tensor_view<>

* Add type constraints to make_tile_window() templates

* Allow passing partition_index to set_tile_if()

* Allow specifying partition_index to store_tile()

* Add missing template parameter of replace_bottom_tensor_view()

* Allow passing partition_index to Default2DEpilogue

* Make get_partition_index() public

* Add _with_offset() postfix to avoid resolution error

* Remove ReplacementPartitionIndex template param

* Add missing comments

* Add load_tile_transpose_with_offset() overload
This commit is contained in:
Po Yen Chen
2025-11-12 10:26:14 +08:00
committed by GitHub
parent 92c1f4981a
commit 40d2ed0f2a
11 changed files with 441 additions and 58 deletions

View File

@@ -214,6 +214,17 @@ CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
printf(">");
}
// Type trait: detects whether T is a ck_tile::sequence<...> specialization.
// Primary template: any other type is not a sequence.
template <typename T>
struct is_sequence : std::false_type
{
};
// Partial specialization: matches sequence<Is...> exactly.
template <index_t... Is>
struct is_sequence<sequence<Is...>> : std::true_type
{
};
// Convenience variable template mirroring the std::*_v idiom.
template <typename T>
inline constexpr bool is_sequence_v = is_sequence<T>::value;
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;

View File

@@ -17,6 +17,19 @@
#include "ck_tile/core/tensor/null_tensor.hpp"
namespace ck_tile {
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
// Loads a tile from `tile_window`, adding `offset` (in elements) to each read address
// before indexing; forwards to TileWindow_::load_with_offset().
// The enable_if on is_class_v<TileWindow_> keeps this overload out of resolution for
// non-class arguments, avoiding ambiguity with other load_tile* overloads.
template <typename TileWindow_,
          index_t i_access = -1,
          bool oob_conditional_check = true,
          typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
                                          index_t offset,
                                          number<i_access> = {},
                                          bool_constant<oob_conditional_check> = {})
{
    return tile_window.load_with_offset(
        offset, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
@@ -49,6 +62,23 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
// Overload writing into a caller-provided distributed tensor `dst_tile` instead of
// returning a new one; `offset` is added (in elements) to each read address.
// SFINAE constrains both arguments to class types to avoid overload ambiguity.
template <typename DistributedTensor_,
          typename TileWindow_,
          index_t i_access = -1,
          bool oob_conditional_check = true,
          typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
                                      std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
                                          const TileWindow_& tile_window,
                                          index_t offset,
                                          number<i_access> = {},
                                          bool_constant<oob_conditional_check> = {})
{
    return tile_window.load_with_offset(
        offset, dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
@@ -112,6 +142,23 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
// Asynchronously loads from `tile_window` into the LDS window `lds_tile`, adding
// `offset` (in elements) to each read address; forwards to
// TileWindow_::async_load_with_offset(). SFINAE constrains both arguments to class
// types to avoid overload ambiguity with async_load_tile().
template <typename LdsTileWindow_,
          typename TileWindow_,
          index_t i_access = -1,
          bool oob_conditional_check = true,
          typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
                                      std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
                                                const TileWindow_& tile_window,
                                                index_t offset,
                                                number<i_access> = {},
                                                bool_constant<oob_conditional_check> = {})
{
    return tile_window.async_load_with_offset(
        offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
@@ -121,8 +168,8 @@ CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
return async_load_tile_with_offset(
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,

View File

@@ -381,6 +381,8 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
@@ -399,18 +401,19 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
@@ -443,4 +446,49 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
return out_tensor;
}
/**
 * @brief Transpose-loads a tile from a tensor and returns the resulting tensor with a new
 * (transposed) tile distribution. Uses SFINAE to ensure the tile distribution encoding is valid.
 *
 * This function is intended for use with statically distributed tensor tiles, where the input
 * and output tile distributions differ due to the transpose operation. It ensures that the
 * element space size and vector length remain consistent between the input and output
 * distributions.
 *
 * @tparam BottomTensorView_ The type of the bottom tensor view.
 * @tparam WindowLengths_ The type representing the window lengths.
 * @tparam TileDistribution_ The type representing the tile distribution.
 * @tparam NumCoord The number of coordinates (dimensions).
 * @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
 * the last is SFINAE to ensure the tile distribution encoding is valid.
 *
 * @param tile_window The tile window with static distribution to load and transpose.
 *
 * @return A statically distributed tensor containing the transposed tile data.
 *
 * @note
 * - The function uses compile-time checks to ensure the input and output tile distributions
 *   are compatible in terms of element space size and vector length.
 * - The transpose operation is performed according to the specified Policy.
 */
template <
    typename BottomTensorView_,
    typename WindowLengths_,
    typename TileDistribution_,
    index_t NumCoord,
    typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
    typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
                                                          typename BottomTensorView_::DataType,
                                                          Policy>::distr_encoding_valid,
                                Policy>>
CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& __restrict__ tile_window)
{
    // Zero-offset convenience wrapper around load_tile_transpose_with_offset().
    return load_tile_transpose_with_offset(tile_window, 0);
}
} // namespace ck_tile

View File

@@ -155,11 +155,11 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi
// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
DistributedIndices distributed_indices)
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(
StaticTileDistribution tile_distribution,
DistributedIndices distributed_indices,
decltype(get_partition_index(tile_distribution)) partition_index)
{
const auto partition_index = detail::get_partition_index(tile_distribution);
constexpr auto y_indices =
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
@@ -170,6 +170,16 @@ get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
return x_coord.get_bottom_index();
}
// get X indices from tuple of tile_distributed_index<>
// Convenience overload: derives the calling thread's partition index from the
// distribution itself and forwards to the three-argument overload that accepts a
// pre-computed partition index.
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
                                       DistributedIndices distributed_indices)
{
    return get_x_indices_from_distributed_indices(
        tile_distribution, distributed_indices, get_partition_index(tile_distribution));
}
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
@@ -192,6 +202,29 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
}
// Overload of set_tile_if() taking a pre-computed `partition_index` (as returned by
// get_partition_index()) instead of deriving it per call. For each distributed element
// whose bottom-tensor X indices satisfy `predicate`, writes `value` into `out_tensor`.
// NOTE(review): sweeps exactly two distributed spans, so this assumes a 2-D
// distribution — confirm against callers.
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
            DataType value,
            XIndicesPredicate predicate,
            decltype(get_partition_index(std::declval<StaticTileDistribution>())) partition_index)
{
    constexpr auto out_spans =
        static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
    sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
            constexpr auto distributed_indices = make_tuple(idx0, idx1);
            // Map distributed (Y) indices to bottom-tensor (X) indices using the
            // caller-supplied partition index.
            const auto x_indices = get_x_indices_from_distributed_indices(
                StaticTileDistribution{}, distributed_indices, partition_index);
            if(predicate(x_indices))
            {
                out_tensor(distributed_indices) = value;
            }
        });
    });
}
// this function used inside span loop over
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)

View File

@@ -9,6 +9,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -38,6 +39,31 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window.store(dstr_tensor);
}
// Overload of store_tile() taking a pre-computed `partition_index`: rebuilds a tile
// window with the distribution of `dstr_tensor` (forwarding the partition index so it
// is not recomputed per thread) and stores the distributed tensor through it.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
           decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;
    // Guard against storing a tensor whose element type differs from the destination view.
    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
    constexpr auto tile_dstr = TileDstr{};
    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr,
                                        partition_index);
    tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -61,6 +87,31 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window.store_raw(dstr_tensor);
}
// Raw-store counterpart of the partition_index overload of store_tile(): identical
// setup, but issues store_raw() on the constructed window.
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
               decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;
    // Guard against storing a tensor whose element type differs from the destination view.
    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
    constexpr auto tile_dstr = TileDstr{};
    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr,
                                        partition_index);
    tile_window.store_raw(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,

View File

@@ -444,6 +444,21 @@ struct null_tensor_view
{
};
// Type trait: detects whether T is a tensor view type.
// Primary template: any other type is not a tensor view.
template <typename T>
struct is_tensor_view : std::false_type
{
};
// Partial specialization: matches any tensor_view instantiation.
template <typename BufferView, typename TensorDesc, memory_operation_enum DstInMemOp>
struct is_tensor_view<tensor_view<BufferView, TensorDesc, DstInMemOp>> : std::true_type
{
};
// null_tensor_view is deliberately treated as a tensor view so null-window
// overloads remain eligible wherever is_tensor_view_v is used as a constraint.
template <>
struct is_tensor_view<null_tensor_view> : std::true_type
{
};
// Convenience variable template mirroring the std::*_v idiom.
template <typename T>
inline constexpr bool is_tensor_view_v = is_tensor_view<T>::value;
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
memory_operation_enum DstInMemOp = memory_operation_enum::set,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,

View File

@@ -17,13 +17,11 @@
namespace ck_tile {
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
return Distribution::_get_partition_index();
return Distribution::get_partition_index();
}
} // namespace detail
// distributed span
template <index_t... PartialHsLengths>
@@ -91,7 +89,7 @@ struct tile_distribution
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static auto _get_partition_index()
CK_TILE_HOST_DEVICE static auto get_partition_index()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
@@ -172,9 +170,9 @@ struct tile_distribution
}
#endif
template <typename PartitionIndex = decltype(_get_partition_index())>
template <typename PartitionIndex = decltype(get_partition_index())>
CK_TILE_HOST_DEVICE auto
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const
{
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto window_adaptor_thread_coord_tmp =
@@ -230,6 +228,23 @@ struct tile_distribution
}
};
// Type trait: detects whether T is a ck_tile::tile_distribution specialization.
// Primary template: any other type is not a tile distribution.
template <typename T>
struct is_tile_distribution : std::false_type
{
};
// Partial specialization: matches any tile_distribution instantiation.
template <typename PsYs2XsAdaptor,
          typename Ys2DDescriptor,
          typename StaticTileDistributionEncoding,
          typename TileDistributionDetail>
struct is_tile_distribution<tile_distribution<PsYs2XsAdaptor,
                                              Ys2DDescriptor,
                                              StaticTileDistributionEncoding,
                                              TileDistributionDetail>> : std::true_type
{
};
// Convenience variable template mirroring the std::*_v idiom.
template <typename T>
inline constexpr bool is_tile_distribution_v = is_tile_distribution<T>::value;
namespace detail {
template <index_t NDimMax>

View File

@@ -189,8 +189,7 @@ struct tile_scatter_gather
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
container_concat(get_partition_index(tile_distribution), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
@@ -836,7 +835,7 @@ struct tile_scatter_gather
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
container_concat(get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =

View File

@@ -12,6 +12,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/utility/functional.hpp"
@@ -67,18 +68,54 @@ struct tile_window_with_static_distribution
const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
const typename Base::TileDstr& tile_distribution,
decltype(get_partition_index(tile_distribution)) partition_index)
: pre_computed_coords_{}
{
this->window_origin_ = window_origin;
this->window_lengths_ = window_lengths;
this->bottom_tensor_view_ = bottom_tensor_view;
this->tile_dstr_ = tile_distribution;
this->window_origin_ = window_origin;
this->window_lengths_ = window_lengths;
this->bottom_tensor_view_ = bottom_tensor_view;
this->tile_dstr_ = tile_distribution;
pre_computed_coords_ =
prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index);
if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
auto use_lane_id_0 = partition_index;
use_lane_id_0[1] = 0;
pre_computed_warp_coords_ =
prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0);
}
}
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: tile_window_with_static_distribution(bottom_tensor_view,
window_lengths,
window_origin,
tile_distribution,
get_partition_index(tile_distribution))
{
}
CK_TILE_DEVICE constexpr auto
prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution,
decltype(get_partition_index(tile_distribution)) partition_index) const
{
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
coords;
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, Base::NDimY>{0}));
container_concat(partition_index, multi_index<Base::NDimY>{0}));
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
@@ -105,18 +142,31 @@ struct tile_window_with_static_distribution
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
return coords;
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
return load_with_offset(
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
load_with_offset(offset,
dst_tensor,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
return dst_tensor;
}
@@ -236,6 +286,19 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
load_with_offset(
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
@@ -258,7 +321,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
@@ -450,10 +513,12 @@ struct tile_window_with_static_distribution
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE auto async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
@@ -472,12 +537,15 @@ struct tile_window_with_static_distribution
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_thread_coord.get_bottom_index();
window_origin + window_adaptor_warp_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
@@ -490,7 +558,7 @@ struct tile_window_with_static_distribution
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
number<0>{},
offset,
bool_constant<oob_conditional_check>{});
// Move thread coordinate if not last access
@@ -503,18 +571,33 @@ struct tile_window_with_static_distribution
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
}
});
});
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose() const
CK_TILE_DEVICE auto load_transpose(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
return this->template load_transpose_with_offset<Policy>(
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
this->template load_transpose<Policy>(
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
this->template load_transpose_with_offset<Policy>(offset,
dst_tensor,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
return dst_tensor;
}
@@ -522,9 +605,10 @@ struct tile_window_with_static_distribution
typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
@@ -550,7 +634,7 @@ struct tile_window_with_static_distribution
const vector_t vec_value =
this->get_bottom_tensor_view()
.template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0);
bottom_tensor_thread_coord, offset);
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto orig_idx_ys = generate_tuple(
@@ -862,16 +946,26 @@ struct tile_window_with_static_distribution
pre_computed_coords_(iCoord)(I1),
step);
});
if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_warp_coords_(iCoord)(I1),
step);
});
}
}
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
{
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(this->tile_dstr_),
array<index_t, Base::NDimY>{0}));
const auto window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(get_partition_index(this->tile_dstr_),
array<index_t, Base::NDimY>{0}));
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
@@ -908,6 +1002,12 @@ struct tile_window_with_static_distribution
// per-thread coordinate for bottom tensor
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
pre_computed_coords_;
// pre_computed_warp_coords_ exists only in the global memory tile_window
std::conditional_t<
Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global,
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>,
std::byte>
pre_computed_warp_coords_;
};
// TODO: use strategy
@@ -929,6 +1029,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution};
}
// Overload of make_tile_window() taking a pre-computed `partition_index`, forwarded to
// the tile_window_with_static_distribution constructor so the window does not recompute
// it. Constrained via is_tensor_view_v / is_tile_distribution_v to keep this overload
// out of resolution for unrelated argument types.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1,
          typename = std::enable_if_t<is_tensor_view_v<TensorView_> &&
                                      is_tile_distribution_v<StaticTileDistribution_>>>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin,
                 const StaticTileDistribution_& tile_distribution,
                 decltype(get_partition_index(tile_distribution)) partition_index,
                 number<NumCoord> = {})
{
    return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                remove_cvref_t<WindowLengths_>,
                                                remove_cvref_t<StaticTileDistribution_>,
                                                NumCoord>{
        tensor_view, window_lengths, origin, tile_distribution, partition_index};
}
// this version can't be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
@@ -1131,15 +1252,25 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
// Overload of make_tile_window() that upgrades a static-lengths window to a
// statically-distributed one, forwarding a pre-computed `partition_index` so it is not
// recomputed inside the new window.
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                 const StaticTileDistribution& tile_distribution,
                 decltype(get_partition_index(tile_distribution)) partition_index)
{
    return make_tile_window(tile_window.get_bottom_tensor_view(),
                            tile_window.get_window_lengths(),
                            tile_window.get_window_origin(),
                            tile_distribution,
                            partition_index);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
auto w = make_tile_window(tile_window, tile_distribution);
w.init_raw();
return w;
}