diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2ea8bf15a7..aa9411b2e1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_scatter_gather.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 716b1f4ecb..d8a5c14f9b 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -34,166 +35,60 @@ template struct tile_window_with_static_distribution + : public tile_window_with_tile_dstr_base< + tile_window_with_static_distribution, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_> { - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - - using DataType = remove_cvref_t; - - static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); - static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + using Base = 
tile_window_with_tile_dstr_base< + tile_window_with_static_distribution, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_>; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static_assert(NumCoord == 1); - // TODO: check WindowLengths and StaticTileDistribution are consistent - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - static_assert(TileDstr::is_static(), "wrong!"); - - static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), - "wrong! inconsistent # of diemsnions"); - - using AdaptorTopIndex = array; - using BottomTensorIndex = array; - - using WindowAdaptorCoord = - decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); - - using BottomTensorCoord = - decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); - - struct load_store_traits - { - private: - static constexpr auto get_vector_dim_y_scalar_per_vector() - { - const auto [ys_vector_lengths, ys_vector_strides] = - tile_window_with_static_distribution:: - get_window_adaptor_ys_safe_vector_length_strides(); - - index_t VectorDimY_ = 0; - index_t ScalarPerVector_ = 1; - - for(index_t i = 0; i < NDimY; ++i) - { - if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) - { - ScalarPerVector_ = ys_vector_lengths[i]; - VectorDimY_ = i; - } - } - - return make_tuple(VectorDimY_, ScalarPerVector_); - } - - public: - static constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); - static constexpr index_t ScalarPerVector = - get_vector_dim_y_scalar_per_vector().template at<1>(); - - // using vector_type_t = vector_type_maker_t; - // using vector_t = typename vector_type_t::type; - using vector_t = thread_buffer; - - private: - static constexpr auto scalars_per_access_ = [] { - constexpr auto scalars_per_access_arr = 
generate_array( - [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); - - /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() - constexpr auto NDimY_ = NDimY; - - return TO_SEQUENCE(scalars_per_access_arr, NDimY_); - }(); - - static constexpr auto get_space_filling_curve() - { - constexpr auto tile_dstr = TileDstr{}; - - constexpr auto thread_tensor_lengths_ys = - to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); - - // FIXME: need logic to judge dim access order - using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; - - return space_filling_curve{}; - } - - public: - using SFC_Ys = decltype(get_space_filling_curve()); - - static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); - - static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); - static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); - }; - - static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + static_assert(Base::Traits::NumAccess % NumCoord == 0, + "wrong! 
# of access is not divisible by NumCoord"); + static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord; CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default; CK_TILE_DEVICE constexpr tile_window_with_static_distribution( - const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin, - const TileDstr& tile_distribution) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin}, - tile_dstr_{tile_distribution}, - pre_computed_coords_{} + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : pre_computed_coords_{} { -#if 0 // debug - // TODO: this use more register for FA, but less register for GEMM - // need investigation - // only support warp-tile and block-tile - static_assert(NDimP == 1 or NDimP == 2, "wrong!"); - WindowAdaptorCoord window_adaptor_thread_coord_tmp; - - if constexpr(NDimP == 1) - { - window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); - } - else if constexpr(NDimP == 2) - { - window_adaptor_thread_coord_tmp = - make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), - AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); - } -#else - // TODO: this use less register for FA, but more register for GEMM - // need investigation + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), container_concat(detail::get_partition_index(tile_distribution), - array{0})); -#endif + 
array{0})); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up // future load/store() calls (might allocate more registers) - using Traits = load_store_traits; + using Traits = typename Base::Traits; using SFC_Ys = typename Traits::SFC_Ys; static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -204,9 +99,10 @@ struct tile_window_with_static_distribution SFC_Ys::get_step_between(number<0>{}, number{}); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); pre_computed_coords_(iCoord) = @@ -214,95 +110,12 @@ struct tile_window_with_static_distribution }); } - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() - { - return TileDstr::is_static(); - } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - CK_TILE_DEVICE constexpr void - 
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move thread's window adaptor coordinate and bottom tensor coordinate - // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] - template - CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( - WindowAdaptorCoord& window_adaptor_thread_coord, - BottomTensorCoord& bottom_tensor_thread_coord, - const ATopIndex& idx_diff_adaptor_top) const - { - array idx_diff_adaptor_bottom; - - move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - window_adaptor_thread_coord, - idx_diff_adaptor_top, - idx_diff_adaptor_bottom); - - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), - bottom_tensor_thread_coord, - idx_diff_adaptor_bottom); - } - - // return vector dimension among [y0, y1, ...] - CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() - { - // bottom tensor top dimension vector lengths and strides - const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = - BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); - - // window vector lengths/strides - const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; - const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; - - // window adaptor [p0, p1, ..., y0, y1, ...] 
- array window_adaptor_vector_lengths{ - -1}; - array window_adaptor_vector_strides{ - -1}; - - constexpr auto window_adaptor_bottom_dims = - WindowAdaptor::get_bottom_dimension_hidden_ids(); - - set_container_subset(window_adaptor_vector_lengths, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_lengths); - set_container_subset(window_adaptor_vector_strides, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_strides); - - const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = - WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( - window_adaptor_vector_lengths, window_adaptor_vector_strides); - - // [y0, y1, ...] - constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; - - return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), - get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); - } - - CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } - template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - constexpr auto tile_dstr = TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + constexpr auto tile_dstr = typename Base::TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, number{}, bool_constant{}); return dst_tensor; } @@ -314,11 +127,11 @@ struct tile_window_with_static_distribution number = {}, bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] 
static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -334,9 +147,8 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, 0, bool_constant{}); -#if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -344,33 +156,26 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / Traits::PackedSize]; + vec_value + .template get_as()[j / Traits::PackedSize]; }); -#else - constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % Traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -386,22 +191,16 @@ struct tile_window_with_static_distribution bool_constant = {}, bool_constant = {}) const { - using Traits = load_store_traits; - - // using vector_type_t = typename Traits::vector_type_t; + using Traits = typename Base::Traits; using vector_t = 
typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; static constexpr index_t YElementSize = - TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0); using vectorized_tbuf = array; - // StaticBuffer; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); @@ -427,7 +226,7 @@ struct tile_window_with_static_distribution Traits::PackedSize; static_assert(d % Traits::ScalarPerVector == 0); - get_bottom_tensor_view().template get_vectorized_elements_raw( + this->get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, 0 /**/, @@ -444,10 +243,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -492,9 +291,8 @@ struct tile_window_with_static_distribution const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent - using Traits = load_store_traits; + using Traits = typename Base::Traits; - // using vector_type_t = typename Traits::vector_type_t; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; @@ -516,7 +314,7 @@ struct tile_window_with_static_distribution }(); // read from bottom tensor - get_bottom_tensor_view().template 
async_get_vectorized_elements_raw( + this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, pre_nop_); // move thread coordinate @@ -525,10 +323,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); m0_inc_with_memory(size_per_issue); @@ -569,7 +367,7 @@ struct tile_window_with_static_distribution const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; @@ -588,7 +386,7 @@ struct tile_window_with_static_distribution constexpr auto iAccess = number{}; // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( + this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, bool_constant{}); // move thread coordinate @@ -597,10 +395,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); smem += size_per_issue; // Note we manually increase the per-issue offset @@ -610,17 +408,18 
@@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, number = {}, bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; // using vector_type_t = typename Traits::vector_type_t; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -643,20 +442,20 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( + this->get_bottom_tensor_view().template set_vectorized_elements( bottom_tensor_thread_coord, 0, vec_value, @@ -668,10 +467,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -679,15 +478,17 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void 
store_raw(const static_distributed_tensor& dstr_tensor, - number = {}) const + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& + dstr_tensor, + number = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; static constexpr bool oob_conditional_check = true; // loop over thread tensor space [y0, y1, ...] @@ -710,16 +511,16 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view() + this->get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, 0, vec_value); @@ -729,10 +530,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -740,16 +541,18 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}) const { - using 
Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -772,18 +575,18 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements( + this->get_bottom_tensor_view().template update_vectorized_elements( bottom_tensor_thread_coord, 0, vec_value, @@ -795,10 +598,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -806,17 +609,19 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update_raw(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename 
Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -839,18 +644,18 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements_raw( + this->get_bottom_tensor_view().template update_vectorized_elements_raw( bottom_tensor_thread_coord, 0, vec_value, @@ -863,70 +668,44 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); }); } - // move thread's botom tensor coordiante - // [x0', x1', ... 
] ==> [offset] - // also move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) + // Custom move behavior + CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step) { - window_origin_ += step; - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), pre_computed_coords_(iCoord)(I1), step); }); } - CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { - window_origin_ = new_window_origin; - -#if 0 // debug - // TODO: this use more register for FA, but less register for GEMM - // need investigation - // only support warp-tile and block-tile - static_assert(NDimP == 1 or NDimP == 2, "wrong!"); - - WindowAdaptorCoord window_adaptor_thread_coord_tmp; - - if constexpr(NDimP == 1) - { - window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); - } - else if constexpr(NDimP == 2) - { - window_adaptor_thread_coord_tmp = - make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); - } -#else // TODO: this use less register for FA, but more register for GEMM // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); -#endif + this->tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(this->tile_dstr_), + array{0})); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + this->window_origin_ + 
window_adaptor_thread_coord_tmp.get_bottom_index(); const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up // future load/store() calls (might allocate more registers) - using Traits = load_store_traits; + using Traits = typename Base::Traits; using SFC_Ys = typename Traits::SFC_Ys; static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -937,9 +716,10 @@ struct tile_window_with_static_distribution SFC_Ys::get_step_between(number<0>{}, number{}); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); pre_computed_coords_(iCoord) = @@ -947,27 +727,11 @@ struct tile_window_with_static_distribution }); } - CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; - - // Tile tensor distribution, which contains: - // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] - // 2. thread descriptor for thread tensor in register: [y0, y1, ...] 
==> [d] - TileDstr tile_dstr_; - // this contains: // per-thread coordinate for window adaptor // per-thread coordinate for bottom tensor - array, NumCoord> pre_computed_coords_; + array, NumCoord> + pre_computed_coords_; }; // TODO: use strategy @@ -1037,62 +801,26 @@ CK_TILE_DEVICE void move_tile_window( */ template struct tile_window_with_static_lengths + : public tile_window_base, + BottomTensorView_, + WindowLengths_> { - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - using DataType = typename BottomTensorView::DataType; - - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - - using BottomTensorIndex = array; + using Base = + tile_window_base, + BottomTensorView_, + WindowLengths_>; CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default; CK_TILE_DEVICE constexpr tile_window_with_static_lengths( - const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin} + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin) { + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; } - - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - 
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) - { - window_origin_ = new_window_origin; - } - - CK_TILE_DEVICE constexpr void - set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; }; template diff --git a/include/ck_tile/core/tensor/tile_window_base.hpp b/include/ck_tile/core/tensor/tile_window_base.hpp new file mode 100644 index 0000000000..89a928a53c --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window_base.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/** + * @brief This class provides description of tile windowed view on the device memory. + * + * @note This class does not provide any functions to read or modify device memory. + * + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. 
+ */ +template +struct tile_window_base +{ + + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + using DataType = remove_cvref_t; + + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + + // Delegate to child if it implements extra logic + static_cast(this)->set_window_origin_extended(new_window_origin); + } + // Default no-op; can be overridden in child + CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {} + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + // Delegate to child if it implements extra movement logic + static_cast(this)->move_extended(step); + } + + // Default no-op; can be overridden in child + CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {} + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + WindowLengths window_lengths_; + + // this is the bottom tensor view + // [x0', x1', ...] 
==> [offset] + BottomTensorView bottom_tensor_view_; +}; + +template +struct tile_window_with_tile_dstr_base + : public tile_window_base +{ + using TileDstr = remove_cvref_t; + using TileWindowBase = tile_window_base; + + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + using AdaptorTopIndex = array; + // using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = decltype(make_tensor_coordinate( + typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{})); + + static_assert(TileDstr::is_static(), "wrong!"); + static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] 
==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + struct Traits + { + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + using vector_t = + thread_buffer; + + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? 
ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto thread_tensor_lengths_ys = + to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + }; + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] 
+ array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; } + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] 
==> [d] + TileDstr tile_dstr_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 5ecaf5ca17..f11610d658 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -37,171 +38,48 @@ namespace ck_tile { // TODO: if using this struct, better use load_raw()/store_raw(), can control // the the immediate offset on the fly // space-filing-curve is non-snaked here! -// +// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class +// with the ultimate parent class being tile_window_base. 
template struct tile_window_linear + : public tile_window_with_tile_dstr_base, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_> { + using Base = tile_window_with_tile_dstr_base, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_>; - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - - using DataType = remove_cvref_t; using LinearBottomDims = remove_cvref_t; - static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension()); - - static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); - static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension()); static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; - // TODO: check WindowLengths and StaticTileDistribution are consistent - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - static_assert(TileDstr::is_static(), "wrong!"); - - static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), - "wrong! inconsistent # of diemsnions"); - - using AdaptorTopIndex = array; - using BottomTensorIndex = array; - - using WindowAdaptorCoord = - decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); - - using BottomTensorCoord = - decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); - struct traits { - private: - // return vector dimension among [y0, y1, ...] 
- CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() - { - // bottom tensor top dimension vector lengths and strides - const auto [bottom_tensor_top_dim_vector_lengths, - bottom_tensor_top_dim_vector_strides] = - BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); - - // window vector lengths/strides - const auto window_adaptor_bottom_dim_vector_lengths = - bottom_tensor_top_dim_vector_lengths; - const auto window_adaptor_bottom_dim_vector_strides = - bottom_tensor_top_dim_vector_strides; - - // window adaptor [p0, p1, ..., y0, y1, ...] - array - window_adaptor_vector_lengths{-1}; - array - window_adaptor_vector_strides{-1}; - - constexpr auto window_adaptor_bottom_dims = - WindowAdaptor::get_bottom_dimension_hidden_ids(); - - set_container_subset(window_adaptor_vector_lengths, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_lengths); - set_container_subset(window_adaptor_vector_strides, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_strides); - - const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = - WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( - window_adaptor_vector_lengths, window_adaptor_vector_strides); - - // [y0, y1, ...] 
- constexpr auto y_dims = - typename arithmetic_sequence_gen::type{}; - - return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), - get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); - } - - static constexpr auto get_vector_dim_y_scalar_per_vector() - { - const auto [ys_vector_lengths, ys_vector_strides] = - get_window_adaptor_ys_safe_vector_length_strides(); - - index_t VectorDimY_ = 0; - index_t ScalarPerVector_ = 1; - - for(index_t i = 0; i < NDimY; ++i) - { - if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) - { - ScalarPerVector_ = ys_vector_lengths[i]; - VectorDimY_ = i; - } - } - - return make_tuple(VectorDimY_, ScalarPerVector_); - } - - public: - static constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); - static constexpr index_t ScalarPerVector = - get_vector_dim_y_scalar_per_vector().template at<1>(); - - using vector_t = thread_buffer; - - private: - static constexpr auto scalars_per_access_ = [] { - constexpr auto scalars_per_access_arr = generate_array( - [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); - - /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() - constexpr auto NDimY_ = NDimY; - - return TO_SEQUENCE(scalars_per_access_arr, NDimY_); - }(); - - static constexpr auto get_space_filling_curve() - { - constexpr auto thread_tensor_lengths_ys = - to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); - - // FIXME: need logic to judge dim access order - using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; - - return space_filling_curve{}; - } - - public: - using SFC_Ys = decltype(get_space_filling_curve()); - - static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); - - static_assert(0 < NumAccess, "Wrong! 
NumAccess should be larger than 0"); - private: static constexpr auto get_num_non_linear_access() { - constexpr auto sfc_access_lens = SFC_Ys::access_lengths; - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear = [&]() { index_t cnt = 1; - static_for<0, NDimY, 1>{}([&](auto i_dim_y) { + static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) { constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; constexpr auto target_h_dim = number{}; // no r dim here! if constexpr(LinearBottomDims{}[target_h_dim] == 0) @@ -230,20 +108,20 @@ struct tile_window_linear // -> prefixsum : seqneuce<0, 2, 4, 6, 8> static constexpr auto get_non_linear_access_map() { - constexpr auto sfc_access_lens = SFC_Ys::access_lengths; - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear_map = [&]() { - array m_{0}; + array m_{0}; index_t cumulative_len_ = 1; index_t cumulative_non_linear_len_ = 1; - static_for<0, NDimY, 1>{}([&](auto i_y) { - constexpr auto i_dim_y = number{}; // from right to left + static_for<0, Base::NDimY, 1>{}([&](auto i_y) { + constexpr auto i_dim_y = number{}; // from right to left constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; constexpr auto target_h_dim = number{}; // no r dim here! 
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim]; - array current_m_{0}; + array current_m_{0}; constexpr auto current_len_ = sfc_access_lens[i_dim_y]; // copy cumulative length as current pattern @@ -266,13 +144,12 @@ struct tile_window_linear return m_; }(); - return TO_SEQUENCE(non_linear_map, NumAccess); + return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess); } static constexpr auto get_non_linear_access_histogram() { constexpr auto m_ = get_non_linear_access_map(); - // m_.foo(); constexpr auto r_ = typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{}; @@ -296,7 +173,7 @@ struct tile_window_linear using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum()); }; - static constexpr index_t NumAccess = traits::NumAccess; + static constexpr index_t NumAccess = Base::Traits::NumAccess; static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear; using AccessMap_NonLinear = typename traits::AccessMap_NonLinear; using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear; @@ -304,30 +181,31 @@ struct tile_window_linear CK_TILE_DEVICE constexpr tile_window_linear() = default; - CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin, - const TileDstr& tile_distribution) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin}, - tile_dstr_{tile_distribution}, - cached_coords_{}, - cached_flags_{} + CK_TILE_DEVICE constexpr tile_window_linear( + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : cached_coords_{}, cached_flags_{} { + this->bottom_tensor_view_ = bottom_tensor_view; + this->window_lengths_ = window_lengths; + 
this->window_origin_ = window_origin; + this->tile_dstr_ = tile_distribution; auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(make_tuple(get_warp_id(), get_lane_id()), - generate_tuple([&](auto) { return number<0>{}; }, number{}))); + container_concat( + make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // future load/store() calls (might allocate more registers) - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto non_linear_id = number{}; @@ -343,16 +221,16 @@ struct tile_window_linear // cached flag is independent from non-linear-coord // but need be updated in move_tile, with proper dims cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp); if constexpr(i_access != (NumAccess - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( 
window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys); @@ -360,54 +238,13 @@ struct tile_window_linear }); } - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() - { - return TileDstr::is_static(); - } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - CK_TILE_DEVICE constexpr void - set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move thread's window adaptor coordinate and bottom tensor coordinate - // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] 
==> [offset] - template - CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( - WindowAdaptorCoord& window_adaptor_thread_coord, - BottomTensorCoord& bottom_tensor_thread_coord, - const ATopIndex& idx_diff_adaptor_top) const - { - array idx_diff_adaptor_bottom; - - move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - window_adaptor_thread_coord, - idx_diff_adaptor_top, - idx_diff_adaptor_bottom); - - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), - bottom_tensor_thread_coord, - idx_diff_adaptor_bottom); - } - template CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number) { - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; constexpr auto idx_ys = SFC_Ys::get_index(number{}); - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto modified_idx_ys = generate_tuple( [&](auto i_dim_y) { @@ -422,9 +259,9 @@ struct tile_window_linear return number{}; } }, - number{}); + number{}); - constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor(); + constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(); constexpr auto idx_ = container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys); @@ -441,8 +278,8 @@ struct tile_window_linear { // this case usually is a LDS window, everything is known at compile tile. 
// we directly use BottomTensorView transform to compute the offset, in case padding - auto bottom_tensor_coord = - make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord); + auto bottom_tensor_coord = make_tensor_coordinate( + typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord); return bottom_tensor_coord.get_offset(); } else @@ -453,7 +290,7 @@ struct tile_window_linear // since that would introduce runtime length (so can't use linear offset) constexpr index_t linear_offset = [&]() { constexpr auto x_idx_ = linear_coord; - constexpr auto x_len_ = TileDstr{}.get_lengths(); + constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths(); static_assert(x_idx_.size() == x_len_.size()); constexpr index_t x_dims_ = x_idx_.size(); index_t cu_stride_ = 1; @@ -469,17 +306,16 @@ struct tile_window_linear } } - CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } - template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + auto dst_tensor = + make_static_distributed_tensor(tile_dstr); auto issue = [&](auto i_access_) { constexpr auto IAccess = number{}; @@ -492,35 +328,29 @@ struct tile_window_linear // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, bool_constant{}); -#if 1 + // data index [y0, y1, ...] 
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j) + : idx_diff_ys[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< + typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; }); -#else - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif }; WINDOW_DISPATCH_ISSUE(); @@ -533,10 +363,10 @@ struct tile_window_linear number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // auto dst_tensor = make_static_distributed_tensor(tile_dstr); @@ -551,35 +381,28 @@ struct tile_window_linear // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, bool_constant{}); -#if 1 // data index [y0, y1, ...] 
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j) + : idx_diff_ys[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< + typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; }); -#else - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif }; WINDOW_DISPATCH_ISSUE(); @@ -596,15 +419,17 @@ struct tile_window_linear bool_constant = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; static constexpr index_t YElementSize = - TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0); + typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) == + 0); using vectorized_tbuf = - array; + array; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; auto& dst_vec_tbuf = 
reinterpret_cast(dst_tensor.get_thread_buffer()); @@ -612,7 +437,7 @@ struct tile_window_linear constexpr auto IAccess = number{}; constexpr auto pre_nop_ = [&]() { if constexpr(pre_nop && i_access_ == 0 && - BottomTensorView::buffer_view::get_address_space() == + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global) return bool_constant{}; else @@ -628,11 +453,11 @@ struct tile_window_linear constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) / - traits::PackedSize; - static_assert(d % traits::ScalarPerVector == 0); + Base::Traits::PackedSize; + static_assert(d % Base::Traits::ScalarPerVector == 0); - get_bottom_tensor_view().template get_vectorized_elements_raw( - dst_vec_tbuf.template at(), + this->get_bottom_tensor_view().template get_vectorized_elements_raw( + dst_vec_tbuf.template at(), bottom_tensor_thread_coord, linear_offset /**/, bottom_tensor_flag, @@ -663,7 +488,7 @@ struct tile_window_linear // currently we only support everything is non linear dim // actually it's not performant if we have linear dim(e.g. 
fast changing) static_assert(NumAccess_NonLinear == NumAccess); - static_assert(BottomTensorView::buffer_view::get_address_space() == + static_assert(Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global); // issues * warps * lanes @@ -689,7 +514,7 @@ struct tile_window_linear const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent - using vector_t = typename traits::vector_t; + using vector_t = typename Base::Traits::vector_t; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; @@ -708,7 +533,7 @@ struct tile_window_linear auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements_raw( + this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); // move thread coordinate @@ -732,7 +557,7 @@ struct tile_window_linear // currently we only support everything is non linear dim // actually it's not performant if we have linear dim(e.g. 
fast changing) static_assert(NumAccess_NonLinear == NumAccess); - static_assert(BottomTensorView::buffer_view::get_address_space() == + static_assert(Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global); // issues * warps * lanes @@ -757,7 +582,7 @@ struct tile_window_linear const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - using vector_t = typename traits::vector_t; + using vector_t = typename Base::Traits::vector_t; // TODO: we force CK_TILE_LDS_ADDR CK_TILE_LDS_ADDR LdsDataType* smem = @@ -771,7 +596,7 @@ struct tile_window_linear auto bottom_tensor_flag = cached_flags_[IAccess]; // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( + this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, @@ -789,15 +614,16 @@ struct tile_window_linear } template - CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -812,22 +638,23 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( + this->get_bottom_tensor_view().template set_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -839,13 +666,15 @@ struct tile_window_linear } template - CK_TILE_DEVICE void store_raw(const static_distributed_tensor& dstr_tensor, - number = {}) const + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& + dstr_tensor, + number = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; static constexpr bool oob_conditional_check = true; // loop over thread tensor space [y0, y1, ...] @@ -861,20 +690,21 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + Base::Traits::PackedSize; + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view() + this->get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value); }; @@ -883,15 +713,17 @@ struct tile_window_linear } template - CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -907,22 +739,24 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()( + j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements( + this->get_bottom_tensor_view().template update_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -934,16 +768,18 @@ struct tile_window_linear } template - CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update_raw(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -959,22 +795,24 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()( + j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements_raw( + this->get_bottom_tensor_view().template update_vectorized_elements_raw( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -985,14 +823,10 @@ struct tile_window_linear WINDOW_DISPATCH_ISSUE(); } - - // move thread's botom tensor coordiante - // [x0', x1', ... ] ==> [offset] - // also move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) + // *_extended() functions act like virtual functions with a default implementation existing + // in the base class + CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step) { - window_origin_ += step; - static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number{}; constexpr auto non_linear_id = number{}; @@ -1001,7 +835,7 @@ struct tile_window_linear if constexpr(need_update_non_linear_coord) { - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), cached_coords_(non_linear_id), step); } @@ -1010,30 +844,29 @@ struct tile_window_linear auto tmp_coords = cached_coords_[non_linear_id]; constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess); move_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord); + this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord); cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid( - bottom_tensor_view_.get_tensor_descriptor(), tmp_coords); + 
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords); }); } - CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { - window_origin_ = new_window_origin; - auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - TileDstr{}.get_ps_ys_to_xs_adaptor(), - container_concat(make_tuple(get_warp_id(), get_lane_id()), - generate_tuple([&](auto) { return number<0>{}; }, number{}))); + typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(), + container_concat( + make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // future load/store() calls (might allocate more registers) - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto non_linear_id = number{}; @@ -1049,10 +882,10 @@ struct tile_window_linear { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys); 
@@ -1060,26 +893,9 @@ struct tile_window_linear }); } - CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; - - // Tile tensor distribution, which contains: - // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] - // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] - TileDstr tile_dstr_; - // this contains: - array cached_coords_; - array cached_flags_; + array cached_coords_; + array cached_flags_; }; #undef WINDOW_DISPATCH_ISSUE