mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905)
* Allow sharing partition index across threads * Fix typo PartitoinIndex -> PartitionIndex * Remove C++20 'requires' usages * Add missing template arguments * Fix load_tile() overload ambiguity issue * Use SFINAE to exclude invalid arguments * Add additional offset parameter to the async_load_tile() * Remove async_load_tile() default argument to avoid ambiguity * Extract tile_window coordinate compute logic as method * Use warp-shared LDS base address in tile_window::async_load() * Add constraint to tile_window::load() templates * Fix wrong type traits is_class_v<> usages * Add missing constraint to async_load_tile() * Add missing tile_window::load() overload * Add more constraint to avoid load_tile() call ambiguity * Rename ParitionIndex as ReplacementPartitionIndex * Update pre_computed_warp_coords_ in move_extended() * Fix inconsistency between template parameters and documentation * Allow specifying pre-computed parition index * Add type straits is_sequence<> & is_tile_distribution<> * Add type straits is_tensor_view<> * Add type constraints to make_tile_window() templates * Allow passing partition_index to set_tile_if() * Allow specifying partition_index to store_tile() * Add missing template parameter of replace_bottom_tensor_view() * Allow passing partition_index to Default2DEpilogue * Make get_partition_index() public * Add _with_offset() postfix to avoid resolution error * Remove ReplacementPartitionIndex template param * Add missing comments * Add load_tile_transpose_with_offset() overload
This commit is contained in:
@@ -214,6 +214,17 @@ CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
|
||||
printf(">");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct is_sequence : std::false_type
|
||||
{
|
||||
};
|
||||
template <index_t... Is>
|
||||
struct is_sequence<sequence<Is...>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_sequence_v = is_sequence<T>::value;
|
||||
|
||||
namespace impl {
|
||||
template <typename T, T... Ints>
|
||||
struct __integer_sequence;
|
||||
|
||||
@@ -17,6 +17,19 @@
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load_with_offset(
|
||||
offset, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
@@ -49,6 +62,23 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
|
||||
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load_with_offset(
|
||||
offset, dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
@@ -112,6 +142,23 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.async_load_with_offset(
|
||||
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
@@ -121,8 +168,8 @@ CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.async_load(
|
||||
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
return async_load_tile_with_offset(
|
||||
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
|
||||
@@ -381,6 +381,8 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* @param offset The offset (in elements) added to the base address before
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
@@ -399,18 +401,19 @@ template <
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
@@ -443,4 +446,49 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
return load_tile_transpose_with_offset(tile_window, 0);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -155,11 +155,11 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices)
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index)
|
||||
{
|
||||
const auto partition_index = detail::get_partition_index(tile_distribution);
|
||||
constexpr auto y_indices =
|
||||
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
|
||||
|
||||
@@ -170,6 +170,16 @@ get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
return x_coord.get_bottom_index();
|
||||
}
|
||||
|
||||
// get X indices from tuple of tile_distributed_index<>
|
||||
template <typename StaticTileDistribution, typename DistributedIndices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
|
||||
DistributedIndices distributed_indices)
|
||||
{
|
||||
return get_x_indices_from_distributed_indices(
|
||||
tile_distribution, distributed_indices, get_partition_index(tile_distribution));
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
@@ -192,6 +202,29 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
|
||||
DataType value,
|
||||
XIndicesPredicate predicate,
|
||||
decltype(get_partition_index(std::declval<StaticTileDistribution>())) partition_index)
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution{}, distributed_indices, partition_index);
|
||||
|
||||
if(predicate(x_indices))
|
||||
{
|
||||
out_tensor(distributed_indices) = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// this function used inside span loop over
|
||||
template <typename YLengths, index_t XUnpacks>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
@@ -38,6 +39,31 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr,
|
||||
partition_index);
|
||||
|
||||
tile_window.store(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
@@ -61,6 +87,31 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void
|
||||
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
|
||||
{
|
||||
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
tile_window_tmp.get_window_lengths(),
|
||||
tile_window_tmp.get_window_origin(),
|
||||
tile_dstr,
|
||||
partition_index);
|
||||
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
|
||||
@@ -444,6 +444,21 @@ struct null_tensor_view
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_tensor_view : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename BufferView, typename TensorDesc, memory_operation_enum DstInMemOp>
|
||||
struct is_tensor_view<tensor_view<BufferView, TensorDesc, DstInMemOp>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_tensor_view<null_tensor_view> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_tensor_view_v = is_tensor_view<T>::value;
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
|
||||
@@ -17,13 +17,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
template <typename Distribution>
|
||||
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
|
||||
{
|
||||
return Distribution::_get_partition_index();
|
||||
return Distribution::get_partition_index();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// distributed span
|
||||
template <index_t... PartialHsLengths>
|
||||
@@ -91,7 +89,7 @@ struct tile_distribution
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto _get_partition_index()
|
||||
CK_TILE_HOST_DEVICE static auto get_partition_index()
|
||||
{
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
@@ -172,9 +170,9 @@ struct tile_distribution
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename PartitionIndex = decltype(_get_partition_index())>
|
||||
template <typename PartitionIndex = decltype(get_partition_index())>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
|
||||
calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const
|
||||
{
|
||||
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
|
||||
const auto window_adaptor_thread_coord_tmp =
|
||||
@@ -230,6 +228,23 @@ struct tile_distribution
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_tile_distribution : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename PsYs2XsAdaptor,
|
||||
typename Ys2DDescriptor,
|
||||
typename StaticTileDistributionEncoding,
|
||||
typename TileDistributionDetail>
|
||||
struct is_tile_distribution<tile_distribution<PsYs2XsAdaptor,
|
||||
Ys2DDescriptor,
|
||||
StaticTileDistributionEncoding,
|
||||
TileDistributionDetail>> : std::true_type
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
inline constexpr bool is_tile_distribution_v = is_tile_distribution<T>::value;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t NDimMax>
|
||||
|
||||
@@ -189,8 +189,7 @@ struct tile_scatter_gather
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
container_concat(get_partition_index(tile_distribution), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
@@ -836,7 +835,7 @@ struct tile_scatter_gather
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
container_concat(get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_view.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
@@ -67,18 +68,54 @@ struct tile_window_with_static_distribution
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
const typename Base::TileDstr& tile_distribution,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index)
|
||||
: pre_computed_coords_{}
|
||||
{
|
||||
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
|
||||
pre_computed_coords_ =
|
||||
prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index);
|
||||
if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
{
|
||||
auto use_lane_id_0 = partition_index;
|
||||
use_lane_id_0[1] = 0;
|
||||
|
||||
pre_computed_warp_coords_ =
|
||||
prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: tile_window_with_static_distribution(bottom_tensor_view,
|
||||
window_lengths,
|
||||
window_origin,
|
||||
tile_distribution,
|
||||
get_partition_index(tile_distribution))
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index) const
|
||||
{
|
||||
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
|
||||
coords;
|
||||
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
container_concat(partition_index, multi_index<Base::NDimY>{0}));
|
||||
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
@@ -105,18 +142,31 @@ struct tile_window_with_static_distribution
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
|
||||
return coords;
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return load_with_offset(
|
||||
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
load_with_offset(offset,
|
||||
dst_tensor,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
@@ -236,6 +286,19 @@ struct tile_window_with_static_distribution
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
load_with_offset(
|
||||
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
@@ -258,7 +321,7 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
@@ -450,10 +513,12 @@ struct tile_window_with_static_distribution
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
|
||||
CK_TILE_DEVICE auto async_load_with_offset(index_t offset,
|
||||
LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
@@ -472,12 +537,15 @@ struct tile_window_with_static_distribution
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
|
||||
auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// Use precomputed window origin
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_thread_coord.get_bottom_index();
|
||||
window_origin + window_adaptor_warp_coord.get_bottom_index();
|
||||
|
||||
// Use precomputed tensor descriptor
|
||||
const auto lds_coord =
|
||||
@@ -490,7 +558,7 @@ struct tile_window_with_static_distribution
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
number<0>{},
|
||||
offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// Move thread coordinate if not last access
|
||||
@@ -503,18 +571,33 @@ struct tile_window_with_static_distribution
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
CK_TILE_DEVICE auto load_transpose(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return this->template load_transpose_with_offset<Policy>(
|
||||
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
this->template load_transpose_with_offset<Policy>(offset,
|
||||
dst_tensor,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
@@ -522,9 +605,10 @@ struct tile_window_with_static_distribution
|
||||
typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset,
|
||||
DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
@@ -550,7 +634,7 @@ struct tile_window_with_static_distribution
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
.template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
bottom_tensor_thread_coord, offset);
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto orig_idx_ys = generate_tuple(
|
||||
@@ -862,16 +946,26 @@ struct tile_window_with_static_distribution
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step);
|
||||
});
|
||||
|
||||
if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
{
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_warp_coords_(iCoord)(I1),
|
||||
step);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
|
||||
{
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(this->tile_dstr_),
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
const auto window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(get_partition_index(this->tile_dstr_),
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
@@ -908,6 +1002,12 @@ struct tile_window_with_static_distribution
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
|
||||
pre_computed_coords_;
|
||||
// pre_computed_warp_coords_ exists only in the global memory tile_window
|
||||
std::conditional_t<
|
||||
Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global,
|
||||
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>,
|
||||
std::byte>
|
||||
pre_computed_warp_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
@@ -929,6 +1029,27 @@ make_tile_window(const TensorView_& tensor_view,
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord = 1,
|
||||
typename = std::enable_if_t<is_tensor_view_v<TensorView_> &&
|
||||
is_tile_distribution_v<StaticTileDistribution_>>>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index,
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, partition_index};
|
||||
}
|
||||
|
||||
// this version can't be called in a constexpr context
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
@@ -1131,15 +1252,25 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
|
||||
tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
decltype(get_partition_index(tile_distribution)) partition_index)
|
||||
{
|
||||
return make_tile_window(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
partition_index);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution)
|
||||
{
|
||||
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution);
|
||||
auto w = make_tile_window(tile_window, tile_distribution);
|
||||
w.init_raw();
|
||||
return w;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user