mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Refactor tile_window.hpp, tile_window_linear.hpp into a CK Tile Hierarchy (#2214)
* window_origin variable now in base class
* abstracted more functions
* consolidated tile_window_static_distribution and tile_window_static_lengths
* clang format
* skeleton code for tile_window and tile_window_linear consolidation
* more abstraction
* moved variables from child to parent
* clang format
* removed comments
* removed debug code
* removed debug code
* abstracting traits WIP
* consolidated traits
* removed comments and clang formatted
[ROCm/composable_kernel commit: 534d4594d0]
This commit is contained in:
@@ -54,6 +54,7 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -34,166 +35,60 @@ template <typename BottomTensorView_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
struct tile_window_with_static_distribution
|
||||
: public tile_window_with_tile_dstr_base<
|
||||
tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
using Base = tile_window_with_tile_dstr_base<
|
||||
tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct load_store_traits
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_static_distribution::
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_)>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
static_assert(Base::Traits::NumAccess % NumCoord == 0,
|
||||
"wrong! # of access is not divisible by NumCoord");
|
||||
static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
pre_computed_coords_{}
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -204,9 +99,10 @@ struct tile_window_with_static_distribution
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
@@ -214,95 +110,12 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
@@ -314,11 +127,11 @@ struct tile_window_with_static_distribution
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -334,9 +147,8 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
@@ -344,33 +156,26 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -386,22 +191,16 @@ struct tile_window_with_static_distribution
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
|
||||
// StaticBuffer<address_space_enum::vgpr,
|
||||
// vector_t,
|
||||
// YElementSize / Traits::ScalarPerVector,
|
||||
// true>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
|
||||
|
||||
@@ -427,7 +226,7 @@ struct tile_window_with_static_distribution
|
||||
Traits::PackedSize;
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
0 /**/,
|
||||
@@ -444,10 +243,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -492,9 +291,8 @@ struct tile_window_with_static_distribution
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
@@ -516,7 +314,7 @@ struct tile_window_with_static_distribution
|
||||
}();
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
@@ -525,10 +323,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
@@ -569,7 +367,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
@@ -588,7 +386,7 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
@@ -597,10 +395,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
@@ -610,17 +408,18 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -643,20 +442,20 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -668,10 +467,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -679,15 +478,17 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
static constexpr bool oob_conditional_check = true;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
@@ -710,16 +511,16 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
this->get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, 0, vec_value);
|
||||
|
||||
@@ -729,10 +530,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -740,16 +541,18 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -772,18 +575,18 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -795,10 +598,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -806,17 +609,19 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -839,18 +644,18 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -863,70 +668,44 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
// Custom move behavior
|
||||
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(this->tile_dstr_),
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -937,9 +716,10 @@ struct tile_window_with_static_distribution
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
@@ -947,27 +727,11 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
|
||||
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
|
||||
pre_computed_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
@@ -1037,62 +801,26 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
*/
|
||||
template <typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_with_static_lengths
|
||||
: public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_>
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = typename BottomTensorView::DataType;
|
||||
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
using Base =
|
||||
tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_>;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin}
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin)
|
||||
{
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
|
||||
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
* @note This class does not provide any functions to read or modify device memory.
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
*/
|
||||
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_base
|
||||
{
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
// Delegate to child if it implements extra logic
|
||||
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
|
||||
}
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
// Delegate to child if it implements extra movement logic
|
||||
static_cast<TileWindowType_*>(this)->move_extended(step);
|
||||
}
|
||||
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
};
|
||||
|
||||
template <typename TileWindowType_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_>
|
||||
struct tile_window_with_tile_dstr_base
|
||||
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
|
||||
{
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord = decltype(make_tensor_coordinate(
|
||||
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
|
||||
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
struct Traits
|
||||
{
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
|
||||
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
using vector_t =
|
||||
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
};
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -37,171 +38,48 @@ namespace ck_tile {
|
||||
// TODO: if using this struct, better use load_raw()/store_raw(), can control
|
||||
// the the immediate offset on the fly
|
||||
// space-filing-curve is non-snaked here!
|
||||
//
|
||||
// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class
|
||||
// with the ultimate parent class being tile_window_base.
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename LinearBottomDims_>
|
||||
struct tile_window_linear
|
||||
: public tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>
|
||||
{
|
||||
using Base = tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>;
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;
|
||||
|
||||
static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct traits
|
||||
{
|
||||
private:
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths,
|
||||
bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths =
|
||||
bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides =
|
||||
bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
|
||||
window_adaptor_vector_lengths{-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
|
||||
window_adaptor_vector_strides{-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims =
|
||||
typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
|
||||
private:
|
||||
static constexpr auto get_num_non_linear_access()
|
||||
{
|
||||
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto non_linear = [&]() {
|
||||
index_t cnt = 1;
|
||||
static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
|
||||
static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) {
|
||||
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
|
||||
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
|
||||
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
|
||||
@@ -230,20 +108,20 @@ struct tile_window_linear
|
||||
// -> prefixsum : seqneuce<0, 2, 4, 6, 8>
|
||||
static constexpr auto get_non_linear_access_map()
|
||||
{
|
||||
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto non_linear_map = [&]() {
|
||||
array<index_t, NumAccess> m_{0};
|
||||
array<index_t, Base::Traits::NumAccess> m_{0};
|
||||
index_t cumulative_len_ = 1;
|
||||
index_t cumulative_non_linear_len_ = 1;
|
||||
static_for<0, NDimY, 1>{}([&](auto i_y) {
|
||||
constexpr auto i_dim_y = number<NDimY - i_y - 1>{}; // from right to left
|
||||
static_for<0, Base::NDimY, 1>{}([&](auto i_y) {
|
||||
constexpr auto i_dim_y = number<Base::NDimY - i_y - 1>{}; // from right to left
|
||||
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
|
||||
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
|
||||
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
|
||||
|
||||
array<index_t, NumAccess> current_m_{0};
|
||||
array<index_t, Base::Traits::NumAccess> current_m_{0};
|
||||
constexpr auto current_len_ = sfc_access_lens[i_dim_y];
|
||||
|
||||
// copy cumulative length as current pattern
|
||||
@@ -266,13 +144,12 @@ struct tile_window_linear
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(non_linear_map, NumAccess);
|
||||
return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
|
||||
}
|
||||
|
||||
static constexpr auto get_non_linear_access_histogram()
|
||||
{
|
||||
constexpr auto m_ = get_non_linear_access_map();
|
||||
// m_.foo();
|
||||
|
||||
constexpr auto r_ =
|
||||
typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
|
||||
@@ -296,7 +173,7 @@ struct tile_window_linear
|
||||
using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccess = traits::NumAccess;
|
||||
static constexpr index_t NumAccess = Base::Traits::NumAccess;
|
||||
static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
|
||||
using AccessMap_NonLinear = typename traits::AccessMap_NonLinear;
|
||||
using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
|
||||
@@ -304,30 +181,31 @@ struct tile_window_linear
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_linear() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
cached_coords_{},
|
||||
cached_flags_{}
|
||||
CK_TILE_DEVICE constexpr tile_window_linear(
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: cached_coords_{}, cached_flags_{}
|
||||
{
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->window_origin_ = window_origin;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
|
||||
container_concat(
|
||||
make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -343,16 +221,16 @@ struct tile_window_linear
|
||||
// cached flag is independent from non-linear-coord
|
||||
// but need be updated in move_tile, with proper dims
|
||||
cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
|
||||
|
||||
if constexpr(i_access != (NumAccess - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord_tmp,
|
||||
bottom_tensor_thread_coord_tmp,
|
||||
idx_diff_ps_ys);
|
||||
@@ -360,54 +238,13 @@ struct tile_window_linear
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
template <index_t i_access>
|
||||
CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
|
||||
{
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto modified_idx_ys = generate_tuple(
|
||||
[&](auto i_dim_y) {
|
||||
@@ -422,9 +259,9 @@ struct tile_window_linear
|
||||
return number<idx_ys[i_dim_y]>{};
|
||||
}
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
|
||||
constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
|
||||
constexpr auto idx_ =
|
||||
container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
|
||||
|
||||
@@ -441,8 +278,8 @@ struct tile_window_linear
|
||||
{
|
||||
// this case usually is a LDS window, everything is known at compile tile.
|
||||
// we directly use BottomTensorView transform to compute the offset, in case padding
|
||||
auto bottom_tensor_coord =
|
||||
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
auto bottom_tensor_coord = make_tensor_coordinate(
|
||||
typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
return bottom_tensor_coord.get_offset();
|
||||
}
|
||||
else
|
||||
@@ -453,7 +290,7 @@ struct tile_window_linear
|
||||
// since that would introduce runtime length (so can't use linear offset)
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
@@ -469,17 +306,16 @@ struct tile_window_linear
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
auto dst_tensor =
|
||||
make_static_distributed_tensor<typename Base::DataTypeDataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
@@ -492,35 +328,29 @@ struct tile_window_linear
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
|
||||
: idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
@@ -533,10 +363,10 @@ struct tile_window_linear
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
|
||||
@@ -551,35 +381,28 @@ struct tile_window_linear
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
|
||||
: idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
dst_tensor.get_thread_buffer().template at<d>() = vec_value.template get_as<
|
||||
typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
@@ -596,15 +419,17 @@ struct tile_window_linear
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
|
||||
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
|
||||
0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
|
||||
array<vector_t,
|
||||
YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
|
||||
|
||||
@@ -612,7 +437,7 @@ struct tile_window_linear
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && i_access_ == 0 &&
|
||||
BottomTensorView::buffer_view::get_address_space() ==
|
||||
Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
@@ -628,11 +453,11 @@ struct tile_window_linear
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
|
||||
traits::PackedSize;
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
Base::Traits::PackedSize;
|
||||
static_assert(d % Base::Traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset /**/,
|
||||
bottom_tensor_flag,
|
||||
@@ -663,7 +488,7 @@ struct tile_window_linear
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(BottomTensorView::buffer_view::get_address_space() ==
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
|
||||
// issues * warps * lanes
|
||||
@@ -689,7 +514,7 @@ struct tile_window_linear
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
@@ -708,7 +533,7 @@ struct tile_window_linear
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
@@ -732,7 +557,7 @@ struct tile_window_linear
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(BottomTensorView::buffer_view::get_address_space() ==
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
|
||||
// issues * warps * lanes
|
||||
@@ -757,7 +582,7 @@ struct tile_window_linear
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
@@ -771,7 +596,7 @@ struct tile_window_linear
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
@@ -789,15 +614,16 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -812,22 +638,23 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -839,13 +666,15 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
static constexpr bool oob_conditional_check = true;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
@@ -861,20 +690,21 @@ struct tile_window_linear
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
Base::Traits::PackedSize;
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
this->get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
|
||||
};
|
||||
@@ -883,15 +713,17 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -907,22 +739,24 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -934,16 +768,18 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -959,22 +795,24 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataTypeDataType>()(
|
||||
j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -985,14 +823,10 @@ struct tile_window_linear
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
// *_extended() functions acts like a virtual function with a default implementation exisiting
|
||||
// in the base class
|
||||
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto IAccess = number<i_access>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -1001,7 +835,7 @@ struct tile_window_linear
|
||||
|
||||
if constexpr(need_update_non_linear_coord)
|
||||
{
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
cached_coords_(non_linear_id),
|
||||
step);
|
||||
}
|
||||
@@ -1010,30 +844,29 @@ struct tile_window_linear
|
||||
auto tmp_coords = cached_coords_[non_linear_id];
|
||||
constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
|
||||
move_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
|
||||
|
||||
cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
TileDstr{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
|
||||
typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(
|
||||
make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -1049,10 +882,10 @@ struct tile_window_linear
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord_tmp,
|
||||
bottom_tensor_thread_coord_tmp,
|
||||
idx_diff_ps_ys);
|
||||
@@ -1060,26 +893,9 @@ struct tile_window_linear
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
// this contains:
|
||||
array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
array<bool, traits::NumAccess> cached_flags_;
|
||||
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
array<bool, Base::Traits::NumAccess> cached_flags_;
|
||||
};
|
||||
|
||||
#undef WINDOW_DISPATCH_ISSUE
|
||||
|
||||
Reference in New Issue
Block a user