Refactor tile_window.hpp, tile_window_linear.hpp into a CK Tile Hierarchy (#2214)

* window_origin variable now in base class

* abstracted more functions

* consolidated tile_window_static_distribution and tile_window_static_lengths

* clang format

* skeleton code for tile_window and tile_window_linear consolidation

* more abstraction

* moved variables from child to parent

* clang format

* removed comments

* removed debug code

* removed debug code

* abstracting traits WIP

* consolidated traits

* removed comments and clang formatted

[ROCm/composable_kernel commit: 534d4594d0]
Aviral Goel
2025-05-22 01:28:00 -05:00
committed by GitHub
parent 598bf07121
commit 7d47d71bc3
4 changed files with 571 additions and 770 deletions
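For orientation before the diffs: the refactor moves the shared window state (window_origin_, window_lengths_, bottom_tensor_view_) into tile_window_base, which uses CRTP to delegate any extra per-window work back to the derived type through optional move_extended()/set_window_origin_extended() hooks. A minimal standalone sketch of that delegation pattern, with simplified names and an int origin instead of the real index arrays (this is not CK code, just the shape of the pattern):

#include <cstdio>

template <typename Derived>
struct window_base
{
    int origin = 0;

    void move(int step)
    {
        origin += step;                                   // shared bookkeeping lives in the base
        static_cast<Derived*>(this)->move_extended(step); // optional hook on the derived window
    }
    // default no-op; a derived window that needs extra work hides this
    void move_extended(int) {}
};

struct linear_window : window_base<linear_window>
{
    int cached_offset = 0;
    // extra per-window work, analogous to updating pre-computed coordinates
    void move_extended(int step) { cached_offset += step; }
};

int main()
{
    linear_window w;
    w.move(3);
    std::printf("origin=%d cached=%d\n", w.origin, w.cached_offset); // origin=3 cached=3
}

In the actual diff, tile_window_with_static_distribution uses its move_extended() hook to shift the pre-computed coordinate bundles, while tile_window_with_static_lengths relies on the default no-op in the base.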

View File

@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"

View File

@@ -13,6 +13,7 @@
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -34,166 +35,60 @@ template <typename BottomTensorView_,
typename StaticTileDistribution_,
index_t NumCoord>
struct tile_window_with_static_distribution
: public tile_window_with_tile_dstr_base<
tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>,
BottomTensorView_,
WindowLengths_,
StaticTileDistribution_>
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
using Base = tile_window_with_tile_dstr_base<
tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>,
BottomTensorView_,
WindowLengths_,
StaticTileDistribution_>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct load_store_traits
{
private:
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_window_with_static_distribution::
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
static_assert(Base::Traits::NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
pre_computed_coords_{}
const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: pre_computed_coords_{}
{
#if 0 // debug
// TODO: this uses more registers for FA, but fewer registers for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this uses fewer registers for FA, but more registers for GEMM
// need investigation
this->window_origin_ = window_origin;
this->window_lengths_ = window_lengths;
this->bottom_tensor_view_ = bottom_tensor_view;
this->tile_dstr_ = tile_distribution;
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
array<index_t, Base::NDimY>{0}));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -204,9 +99,10 @@ struct tile_window_with_static_distribution
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
@@ -214,95 +110,12 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
@@ -314,11 +127,11 @@ struct tile_window_with_static_distribution
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -334,9 +147,8 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
@@ -344,33 +156,26 @@ struct tile_window_with_static_distribution
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
vec_value
.template get_as<typename Base::DataType>()[j / Traits::PackedSize];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
@@ -386,22 +191,16 @@ struct tile_window_with_static_distribution
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
using vectorized_tbuf =
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
// StaticBuffer<address_space_enum::vgpr,
// vector_t,
// YElementSize / Traits::ScalarPerVector,
// true>;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
@@ -427,7 +226,7 @@ struct tile_window_with_static_distribution
Traits::PackedSize;
static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
0 /**/,
@@ -444,10 +243,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
@@ -492,9 +291,8 @@ struct tile_window_with_static_distribution
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using Traits = load_store_traits;
using Traits = typename Base::Traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
@@ -516,7 +314,7 @@ struct tile_window_with_static_distribution
}();
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate
@@ -525,10 +323,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
m0_inc_with_memory(size_per_issue);
@@ -569,7 +367,7 @@ struct tile_window_with_static_distribution
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
@@ -588,7 +386,7 @@ struct tile_window_with_static_distribution
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate
@@ -597,10 +395,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
@@ -610,17 +408,18 @@ struct tile_window_with_static_distribution
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
typename Base::TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using Traits = typename Base::Traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -643,20 +442,20 @@ struct tile_window_with_static_distribution
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
@@ -668,10 +467,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
@@ -679,15 +478,17 @@ struct tile_window_with_static_distribution
}
template <index_t i_access_unsupport_ = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {}) const
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access_unsupport_> = {}) const
{
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
@@ -710,16 +511,16 @@ struct tile_window_with_static_distribution
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view()
this->get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, 0, vec_value);
@@ -729,10 +530,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
@@ -740,16 +541,18 @@ struct tile_window_with_static_distribution
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE void
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -772,18 +575,18 @@ struct tile_window_with_static_distribution
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
@@ -795,10 +598,10 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
@@ -806,17 +609,19 @@ struct tile_window_with_static_distribution
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
CK_TILE_DEVICE void
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -839,18 +644,18 @@ struct tile_window_with_static_distribution
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
@@ -863,70 +668,44 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
// Custom move behavior
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_coords_(iCoord)(I1),
step);
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this uses more registers for FA, but fewer registers for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this uses fewer registers for FA, but more registers for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(this->tile_dstr_),
array<index_t, Base::NDimY>{0}));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using Traits = typename Base::Traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
@@ -937,9 +716,10 @@ struct tile_window_with_static_distribution
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
@@ -947,27 +727,11 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
pre_computed_coords_;
};
// TODO: use strategy
@@ -1037,62 +801,26 @@ CK_TILE_DEVICE void move_tile_window(
*/
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
: public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
BottomTensorView_,
WindowLengths_>
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = typename BottomTensorView::DataType;
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using Base =
tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
BottomTensorView_,
WindowLengths_>;
CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;
CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin}
const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin)
{
this->window_origin_ = window_origin;
this->window_lengths_ = window_lengths;
this->bottom_tensor_view_ = bottom_tensor_view;
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
}
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
};
template <typename TensorView_, typename WindowLengths_>

View File

@@ -0,0 +1,256 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/**
* @brief This class provides a description of a tile windowed view on device memory.
*
* @note This class does not provide any functions to read or modify device memory.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
*/
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
struct tile_window_base
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
// Delegate to child if it implements extra logic
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
}
// Default no-op; can be overridden in child
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
window_origin_ += step;
// Delegate to child if it implements extra movement logic
static_cast<TileWindowType_*>(this)->move_extended(step);
}
// Default no-op; can be overridden in child
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
WindowLengths window_lengths_;
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
};
template <typename TileWindowType_,
typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_>
struct tile_window_with_tile_dstr_base
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
{
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord = decltype(make_tensor_coordinate(
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
static_assert(TileDstr::is_static(), "wrong!");
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
struct Traits
{
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t =
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto thread_tensor_lengths_ys =
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_),
false /*!!! no snaked curve! */>{};
}
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
};
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
};
} // namespace ck_tile
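
A note on the consolidated Traits in this new header: get_vector_dim_y_scalar_per_vector() scans the Y dimensions and picks the one with unit stride and the largest safe vector length as the vector dimension. A standalone illustration of that selection loop (sample lengths/strides are made up for the example; this is not CK code):

#include <array>
#include <cstdio>
#include <utility>

std::pair<int, int> pick_vector_dim(const std::array<int, 3>& lengths,
                                    const std::array<int, 3>& strides)
{
    int vector_dim_y      = 0;
    int scalar_per_vector = 1;
    for(int i = 0; i < 3; ++i)
    {
        // prefer the contiguous (stride-1) Y dim with the largest vector length
        if(strides[i] == 1 && lengths[i] > scalar_per_vector)
        {
            scalar_per_vector = lengths[i];
            vector_dim_y      = i;
        }
    }
    return {vector_dim_y, scalar_per_vector};
}

int main()
{
    // only the last Y dim is contiguous here, so it becomes the vector dim
    auto [dim, spv] = pick_vector_dim({4, 2, 8}, {16, 8, 1});
    std::printf("VectorDimY=%d ScalarPerVector=%d\n", dim, spv); // VectorDimY=2 ScalarPerVector=8
}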

View File

@@ -13,6 +13,7 @@
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
@@ -37,171 +38,48 @@ namespace ck_tile {
// TODO: if using this struct, it is better to use load_raw()/store_raw(), which can control
// the immediate offset on the fly
// space-filling-curve is non-snaked here!
//
// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class
// with the ultimate parent class being tile_window_base.
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
struct tile_window_linear
: public tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>,
BottomTensorView_,
WindowLengths_,
StaticTileDistribution_>
{
using Base = tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>,
BottomTensorView_,
WindowLengths_,
StaticTileDistribution_>;
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;
static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct traits
{
private:
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
// bottom tensor top dimension vector lengths and strides
const auto [bottom_tensor_top_dim_vector_lengths,
bottom_tensor_top_dim_vector_strides] =
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
// window vector lengths/strides
const auto window_adaptor_bottom_dim_vector_lengths =
bottom_tensor_top_dim_vector_lengths;
const auto window_adaptor_bottom_dim_vector_strides =
bottom_tensor_top_dim_vector_strides;
// window adaptor [p0, p1, ..., y0, y1, ...]
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_lengths{-1};
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
window_adaptor_vector_strides{-1};
constexpr auto window_adaptor_bottom_dims =
WindowAdaptor::get_bottom_dimension_hidden_ids();
set_container_subset(window_adaptor_vector_lengths,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_lengths);
set_container_subset(window_adaptor_vector_strides,
window_adaptor_bottom_dims,
window_adaptor_bottom_dim_vector_strides);
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims =
typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
static constexpr auto get_space_filling_curve()
{
constexpr auto thread_tensor_lengths_ys =
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_),
false /*!!! no snaked curve! */>{};
}
public:
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
private:
static constexpr auto get_num_non_linear_access()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear = [&]() {
index_t cnt = 1;
static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) {
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
@@ -230,20 +108,20 @@ struct tile_window_linear
// -> prefixsum : sequence<0, 2, 4, 6, 8>
static constexpr auto get_non_linear_access_map()
{
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto non_linear_map = [&]() {
array<index_t, NumAccess> m_{0};
array<index_t, Base::Traits::NumAccess> m_{0};
index_t cumulative_len_ = 1;
index_t cumulative_non_linear_len_ = 1;
static_for<0, NDimY, 1>{}([&](auto i_y) {
constexpr auto i_dim_y = number<NDimY - i_y - 1>{}; // from right to left
static_for<0, Base::NDimY, 1>{}([&](auto i_y) {
constexpr auto i_dim_y = number<Base::NDimY - i_y - 1>{}; // from right to left
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
array<index_t, NumAccess> current_m_{0};
array<index_t, Base::Traits::NumAccess> current_m_{0};
constexpr auto current_len_ = sfc_access_lens[i_dim_y];
// copy cumulative length as current pattern
@@ -266,13 +144,12 @@ struct tile_window_linear
return m_;
}();
return TO_SEQUENCE(non_linear_map, NumAccess);
return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
}
static constexpr auto get_non_linear_access_histogram()
{
constexpr auto m_ = get_non_linear_access_map();
// m_.foo();
constexpr auto r_ =
typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
@@ -296,7 +173,7 @@ struct tile_window_linear
using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
};
static constexpr index_t NumAccess = traits::NumAccess;
static constexpr index_t NumAccess = Base::Traits::NumAccess;
static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
using AccessMap_NonLinear = typename traits::AccessMap_NonLinear;
using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
@@ -304,30 +181,31 @@ struct tile_window_linear
CK_TILE_DEVICE constexpr tile_window_linear() = default;
CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
cached_coords_{},
cached_flags_{}
CK_TILE_DEVICE constexpr tile_window_linear(
const typename Base::BottomTensorView& bottom_tensor_view,
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: cached_coords_{}, cached_flags_{}
{
this->bottom_tensor_view_ = bottom_tensor_view;
this->window_lengths_ = window_lengths;
this->window_origin_ = window_origin;
this->tile_dstr_ = tile_distribution;
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
container_concat(
make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
using SFC_Ys = typename Base::Traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
@@ -343,16 +221,16 @@ struct tile_window_linear
// cached flag is independent of the non-linear coord
// but needs to be updated in move_tile, with proper dims
cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
if constexpr(i_access != (NumAccess - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
@@ -360,54 +238,13 @@ struct tile_window_linear
});
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
window_adaptor_thread_coord,
idx_diff_adaptor_top,
idx_diff_adaptor_bottom);
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_coord,
idx_diff_adaptor_bottom);
}
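// Illustrative sketch (standalone C++, simplified to 2-D; the types below are stand-ins, not
// CK types): the helper above moves in two stages -- the [p..., y...] step is first pushed
// through the adaptor to get an [x...] step, and only then is the bottom tensor coordinate
// (index plus flat offset) moved by that [x...] step.
#include <array>

struct adaptor_sketch // stand-in for the ps_ys -> xs adaptor; identity mapping for brevity
{
    std::array<int, 2> map_step(const std::array<int, 2>& top_step) const
    {
        return {top_step[0], top_step[1]};
    }
};

struct bottom_coord_sketch // stand-in for the bottom tensor coordinate
{
    std::array<int, 2> idx{};
    long offset = 0;
};

inline void move_both(const adaptor_sketch& adaptor,
                      bottom_coord_sketch& bottom,
                      const std::array<int, 2>& top_step,
                      const std::array<long, 2>& strides)
{
    const auto bottom_step = adaptor.map_step(top_step); // stage 1: top step -> bottom step
    bottom.idx[0] += bottom_step[0];                     // stage 2: move coordinate and offset
    bottom.idx[1] += bottom_step[1];
    bottom.offset += bottom_step[0] * strides[0] + bottom_step[1] * strides[1];
}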
template <index_t i_access>
CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
{
using SFC_Ys = typename traits::SFC_Ys;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
using ys_to_rhs_major =
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
using ys_to_rhs_major = typename decltype(
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
constexpr auto modified_idx_ys = generate_tuple(
[&](auto i_dim_y) {
@@ -422,9 +259,9 @@ struct tile_window_linear
return number<idx_ys[i_dim_y]>{};
}
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
constexpr auto idx_ =
container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
@@ -441,8 +278,8 @@ struct tile_window_linear
{
// this case is usually an LDS window; everything is known at compile time.
// we directly use the BottomTensorView transform to compute the offset, in case of padding
auto bottom_tensor_coord =
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
auto bottom_tensor_coord = make_tensor_coordinate(
typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord);
return bottom_tensor_coord.get_offset();
}
else
@@ -453,7 +290,7 @@ struct tile_window_linear
// since that would introduce runtime length (so can't use linear offset)
constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
@@ -469,17 +306,16 @@ struct tile_window_linear
}
}
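// Illustrative sketch (standalone C++, not CK code): the constexpr lambda above folds the X
// index into one packed offset; its loop body is elided by the hunk, but the usual
// cumulative-stride formulation is offset = sum_i idx[i] * prod_{j > i} len[j], walking from
// the fastest-changing dimension.
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr int packed_offset(const std::array<int, N>& idx, const std::array<int, N>& len)
{
    int offset    = 0;
    int cu_stride = 1;
    for(std::size_t r = 0; r < N; ++r) // walk from the last (fastest) dimension backwards
    {
        const std::size_t i = N - 1 - r;
        offset += idx[i] * cu_stride;
        cu_stride *= len[i];
    }
    return offset;
}

// e.g. a 4 x 8 window: element (2, 3) lands at 2 * 8 + 3 = 19
static_assert(packed_offset<2>({2, 3}, {4, 8}) == 19, "packed offset sketch");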
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto dst_tensor =
make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
@@ -492,35 +328,29 @@ struct tile_window_linear
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
: idx_diff_ys[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataType>()[j / Base::Traits::PackedSize];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
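// Illustrative sketch (standalone C++, not CK code; ys_to_d and the sizes are made-up
// stand-ins): the static_for above splits one loaded vector into scalars and drops each one
// at its slot d in the thread buffer, where d comes from the ys -> d descriptor and j walks
// the vector along the vectorized y dimension (VectorDimY).
#include <array>

constexpr int ScalarPerVector = 4;

constexpr int ys_to_d(int y0, int y1) { return y0 * 8 + y1; } // stand-in for calculate_offset

inline void scatter_vector(std::array<float, 64>& thread_buf,
                           const std::array<float, ScalarPerVector>& vec,
                           int y0,
                           int y1_start) // start index along the vectorized (fastest) y dim
{
    // caller keeps (y0, y1_start + ScalarPerVector - 1) inside the thread-buffer extent
    for(int j = 0; j < ScalarPerVector; ++j)
    {
        const int d   = ys_to_d(y0, y1_start + j); // bump only the vectorized y by j
        thread_buf[d] = vec[j];
    }
}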
@@ -533,10 +363,10 @@ struct tile_window_linear
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
@@ -551,35 +381,28 @@ struct tile_window_linear
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
: idx_diff_ys[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
Base::Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / traits::PackedSize];
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
typename Base::DataType>()[j / Base::Traits::PackedSize];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
@@ -596,15 +419,17 @@ struct tile_window_linear
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
0);
using vectorized_tbuf =
array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
array<vector_t,
YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
@@ -612,7 +437,7 @@ struct tile_window_linear
constexpr auto IAccess = number<i_access_>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && i_access_ == 0 &&
BottomTensorView::buffer_view::get_address_space() ==
Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
return bool_constant<true>{};
else
@@ -628,11 +453,11 @@ struct tile_window_linear
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
traits::PackedSize;
static_assert(d % traits::ScalarPerVector == 0);
Base::Traits::PackedSize;
static_assert(d % Base::Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
linear_offset /**/,
bottom_tensor_flag,
@@ -663,7 +488,7 @@ struct tile_window_linear
// currently we only support all dims being non-linear;
// it's actually not performant if we have a linear dim (e.g. fast changing)
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
@@ -689,7 +514,7 @@ struct tile_window_linear
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
using vector_t = typename traits::vector_t;
using vector_t = typename Base::Traits::vector_t;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
@@ -708,7 +533,7 @@ struct tile_window_linear
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
// move thread coordinate
@@ -732,7 +557,7 @@ struct tile_window_linear
// currently we only support all dims being non-linear;
// it's actually not performant if we have a linear dim (e.g. fast changing)
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(BottomTensorView::buffer_view::get_address_space() ==
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
@@ -757,7 +582,7 @@ struct tile_window_linear
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename traits::vector_t;
using vector_t = typename Base::Traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
@@ -771,7 +596,7 @@ struct tile_window_linear
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
@@ -789,15 +614,16 @@ struct tile_window_linear
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
typename Base::TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
@@ -812,22 +638,23 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
Base::Traits::PackedSize;
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
@@ -839,13 +666,15 @@ struct tile_window_linear
}
template <index_t i_access = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {}) const
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
static constexpr bool oob_conditional_check = true;
// loop over thread tensor space [y0, y1, ...]
@@ -861,20 +690,21 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
Base::Traits::PackedSize;
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view()
this->get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
};
@@ -883,15 +713,17 @@ struct tile_window_linear
}
template <index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE void
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
@@ -907,22 +739,24 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
Base::Traits::PackedSize;
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(
j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
@@ -934,16 +768,18 @@ struct tile_window_linear
}
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
CK_TILE_DEVICE void
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
using vector_t = typename Base::Traits::vector_t;
using SFC_Ys = typename Base::Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
constexpr auto tile_dstr = typename Base::TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
@@ -959,22 +795,24 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
number<Base::NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
Base::Traits::PackedSize;
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
vec_value.template get_as<typename Base::DataType>()(
j / Base::Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
@@ -985,14 +823,10 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
// *_extended() functions act like virtual functions with a default implementation existing
// in the base class
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
{
window_origin_ += step;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto IAccess = number<i_access>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
@@ -1001,7 +835,7 @@ struct tile_window_linear
if constexpr(need_update_non_linear_coord)
{
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
cached_coords_(non_linear_id),
step);
}
@@ -1010,30 +844,29 @@ struct tile_window_linear
auto tmp_coords = cached_coords_[non_linear_id];
constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
move_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
});
}
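// Illustrative sketch (standalone C++, not CK code): the *_extended() hooks follow the CRTP
// pattern described in the comment above -- the base class owns the common part of move()
// and then forwards, without virtual dispatch, to whatever move_extended() the derived
// window provides (or to the base's empty default if it provides none).
#include <array>

template <typename Derived>
struct window_base_sketch
{
    std::array<int, 2> window_origin_{};

    void move(const std::array<int, 2>& step)
    {
        window_origin_[0] += step[0]; // common bookkeeping lives in the base
        window_origin_[1] += step[1];
        static_cast<Derived&>(*this).move_extended(step); // derived-specific follow-up
    }

    void move_extended(const std::array<int, 2>&) {} // default: nothing extra to do
};

struct linear_window_sketch : window_base_sketch<linear_window_sketch>
{
    int refreshed = 0;
    void move_extended(const std::array<int, 2>&) { ++refreshed; } // e.g. refresh cached coords
};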
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
{
window_origin_ = new_window_origin;
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
TileDstr{}.get_ps_ys_to_xs_adaptor(),
container_concat(make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
container_concat(
make_tuple(get_warp_id(), get_lane_id()),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// future load/store() calls (might allocate more registers)
using SFC_Ys = typename traits::SFC_Ys;
using SFC_Ys = typename Base::Traits::SFC_Ys;
static_for<0, NumAccess, 1>{}([&](auto i_access) {
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
@@ -1049,10 +882,10 @@ struct tile_window_linear
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord_tmp,
bottom_tensor_thread_coord_tmp,
idx_diff_ps_ys);
@@ -1060,26 +893,9 @@ struct tile_window_linear
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
// this contains:
array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<bool, traits::NumAccess> cached_flags_;
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<bool, Base::Traits::NumAccess> cached_flags_;
};
#undef WINDOW_DISPATCH_ISSUE