diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 4fcea9642d..b187b71830 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1178,6 +1178,15 @@ struct reverse_slice_sequence_impl, sequence, sequence, Slice // clang-format off // input a sequence(with optional mask), and the SliceSize : size per slice // output the sequence each slice, and number of slices +// the length count for slice size is from right to left(reverse slice) +// or we can say, find the greatest common divisor(gcd) from right to left, for the slice length +// +// e.g. <2, 8, 4>, slice length = 16 +// step-1: we take the right most <*, *, 4>, remaining 16/4=4 +// step-2: we only need 4 out of 8, of the middle dim, hence <*, 4, 4> +// step-3: since nothing remains, so the first dim we only need 1, hence <1, 4, 4> +// => we got <1, 4, 4> as length for each slice +// => total number of slice = <2, 8, 4> / <1, 4, 4> = <2, 2, 1> // // e.g. 
<2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0 // <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2 @@ -1197,7 +1206,7 @@ struct reverse_slice_sequence_impl, sequence, sequence, Slice // // return tuple, slice_index is at which index will start // have split slices (right -> left) -// or the first index that sliced length is different from the original length +// or the first index (right -> left) that sliced length is different from the original length // clang-format on template ::type{}) { static_assert(Seq::size() == Mask::size()); + static_assert(SliceSize != 0, "slice size zero is invalid"); + static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) % + SliceSize == + 0, + "slice size can't evenly divide input sizes"); using sliced_type = impl::reverse_slice_sequence_impl - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length) +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one) // Y P P Y P Y P Y // => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK // |--> slice along this Y dim, is the first dim of X1, totally 4 slices // // X0 X1 -// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length) +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one) // Y P P Y P Y P Y // => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK // |--> slice along this Y dim, the P dim is 1 in the left, so is OK // totally 16 slices // // X0 X1 -// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length) +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one) // Y P P Y P Y P Y // => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail // |--> slice along this P dim, will split threads, not supported // // X0 X1 -// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length) +// <1, 4, 32> - 
<4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one) // Y P P Y P Y P Y // => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK // |--> slice along this Y dim, but this Y sim need to split into 2 @@ -577,11 +577,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( using Encoding = decltype(Distribution::get_static_tile_distribution_encoding()); static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds)); + static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r"); - constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins; + constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h(); + + constexpr auto x_slice_ends_ = generate_sequence_v2( + [&](auto i) { + if constexpr(x_slice_ends[i] == -1) + { + // -1 means till the end + constexpr auto x_length_ = + container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{}); + return x_length_; + } + else + { + return x_slice_ends[i]; + } + }, + number{}); + + constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins; + + constexpr auto x_slice_lengths_without_p = generate_sequence_v2( + [&](auto i) constexpr { + constexpr auto len_ = x_slice_lengths[i]; + static_assert(len_ % p_len_over_h[i] == 0, + "slice length must be dividable by p_len_over_h"); + return number{}; + }, + number{}); constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum(); - constexpr auto src_y_info = Encoding::detail::get_sorted_y_info(); + constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info(); constexpr auto src_y_dims = src_y_info[number<0>{}]; constexpr auto src_y_maps = src_y_info[number<1>{}]; constexpr auto src_y_prefix_sum = src_y_info[number<2>{}]; @@ -590,14 +618,15 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( { auto y_slice_sorted_origins = make_zero_multi_index(); auto y_slice_lengths = Encoding::detail::ys_lengths_; + constexpr auto y_to_h_masks = 
Encoding::detail::get_y_to_h_masks(); // This lambda will modify some value outside, so c++ will not treat return value as // constexpr // TODO: ugly auto new_h_lengths = transform_tuples( [&](auto h_len, auto id) { - constexpr auto sliced_h = - reverse_slice_sequence(h_len, number{}); + constexpr auto sliced_h = reverse_slice_sequence( + h_len, number{}, y_to_h_masks[id]); constexpr auto sliced_h_lens = sliced_h[number<0>{}]; constexpr auto sliced_h_index = sliced_h[number<2>{}]; @@ -605,26 +634,39 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( // update y_slice_lengths constexpr auto uniformed_h_index = sliced_h_index + number{}; constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index); + constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1]; static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(), "not sliced at y dim, please check"); - static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { - y_slice_lengths(src_y_maps[found_y_index - i]) = - sliced_h_lens[sliced_h_index - i]; - }); + { + constexpr auto sliced_y_to_h_lens = + pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]); + constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size(); + static_for<0, sliced_y_to_h_dims, 1>{}([&](auto i) { + y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) = + sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i]; + }); + } // TODO: add validations not across p dim // NOTE: this y_origin is for all dims, not only current dim // will later use pick to select target dim constexpr auto y_origin = [&]() { - constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len); - auto h_origin_ = make_zero_multi_index(); - h_trans.calculate_lower_index(h_origin_, sequence{}); + // can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered + constexpr auto y_to_h_len = + pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]); + constexpr auto y_to_h_dims = y_to_h_len.size(); + + constexpr auto h_trans = 
make_merge_transform_v3_division_mod(y_to_h_len); + auto h_origin_ = make_zero_multi_index(); + constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id]; + h_trans.calculate_lower_index(h_origin_, sequence{}); auto y_origin_ = make_zero_multi_index(); - static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { - y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i]; + + static_for<0, y_to_h_dims, 1>{}([&](auto i) { + y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i]; }); return y_origin_; }(); diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index 7b1e952025..30cd698595 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -255,33 +255,107 @@ struct tile_distribution_encoding } }(); - // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8> - CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum() + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_h_dim_lengths() { - // // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> constexpr auto uniformed_h_dim_lengths = generate_sequence_v2( [&](auto i) { - constexpr index_t size = HsLengthss{}[i].size(); - return number{}; + constexpr index_t size_ = HsLengthss{}[i].size(); + return number{}; }, number{}); + return uniformed_h_dim_lengths; + } + // note: this function only count the p dim length along h, not r + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_p_dim_lengths_over_h() + { + // e.g. tuple, seq<1, 2, 8, 4, 4>> + // Y P Y Y P Y P Y + // | | | + // v v v + // return : seq<4, 2 * 4> => seq<4, 8> + constexpr auto uniformed_ps_to_rhss_major_ = + unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_); + constexpr auto uniformed_ps_to_rhss_minor_ = + unpack([](auto... 
xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_); + + constexpr auto p_len_ = [&]() { + array len_{1}; + static_for<0, NDimX, 1>{}([&](auto idim_x_) { + constexpr auto major_ = number{}; // RDim + static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) { + if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_]) + { + constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_]; + constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_]; + len_[idim_x_] *= h_length_; + } + }); + }); + return len_; + }(); + constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX); + return p_len_over_h_seq_; + } + + // + // R: seq<3>, H: tuple, seq<4, 1, 4, 2, 4>> + // => return seq<1, 3, 5> + // R: seq<>, H: tuple, seq<16, 8, 8>> + // => return seq<0, 2, 3> + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_rh_dim_lengths() + { + constexpr auto uniformed_rh_dim_lengths = + merge_sequences(sequence{} /*for R dims*/, get_uniformed_h_dim_lengths()); + + return uniformed_rh_dim_lengths; + } + + // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8> + CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum() + { // <0, len_d0, len_d0+len_d1, ...> // e.g. seq<3, 5> --> seq<0, 3, 8> - constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths); + constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths()); return h_dim_prefix_sum; } - CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h() + CK_TILE_HOST_DEVICE static constexpr auto get_rh_dim_lengths_prefix_sum() + { + // <0, len_d0, len_d0+len_d1, ...> + // e.g. seq<3, 5> --> seq<0, 3, 8> + constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths()); + + return rh_dim_prefix_sum; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_p_to_h() + { + // tuple, seq> -> seq + constexpr auto uniformed_ps_to_rhss_major_ = + unpack([](auto... 
xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_); + constexpr auto uniformed_ps_to_rhss_minor_ = + unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_); + + constexpr auto all_ps_2_rhss = transform_sequences( + [](auto major, auto minor) constexpr { + constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum(); + return rh_dim_prefix_sum.at(major) + minor; + }, + uniformed_ps_to_rhss_major_, + uniformed_ps_to_rhss_minor_); + + return all_ps_2_rhss; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_rh() { constexpr auto all_ys_2_rhss = transform_sequences( [](auto major, auto minor) constexpr { - // <0, 0, len_d0, len_d0+len_d1, ...> - constexpr auto x_dim_prefix_sum = merge_sequences( - sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum()); - return x_dim_prefix_sum.at(major) + minor; + constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum(); + return rh_dim_prefix_sum.at(major) + minor; }, Ys2RHsMajor{}, Ys2RHsMinor{}); @@ -289,6 +363,45 @@ struct tile_distribution_encoding return all_ys_2_rhss; } + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h() + { + // TODO: Y can't point to R + constexpr auto all_ys_2_rhss = transform_sequences( + [](auto major, auto minor) constexpr { + constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum(); + return rh_dim_prefix_sum.at(major) + minor - NDimR; + }, + Ys2RHsMajor{}, + Ys2RHsMinor{}); + + return all_ys_2_rhss; + } + + // return tuple of seq + CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks() + { + constexpr auto masks_ = generate_tuple( + [&](auto i) { + constexpr auto size_ = HsLengthss{}[i].size(); + constexpr auto current_y_to_h_mask_ = [&]() { + array m_{0}; + // TODO: we loop over all y for each h dim + for(auto j = 0; j < NDimY; j++) + { + if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/) + { + m_[Ys2RHsMinor{}[j]] = 1; + } + } + return m_; + }(); + + return 
TO_SEQUENCE(current_y_to_h_mask_, size_); + }, + number{}); + return masks_; + } + // return tuple template CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq) @@ -305,7 +418,8 @@ struct tile_distribution_encoding return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum); } - CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info() + // Note here y_to_h does not count R dim! + CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_to_h_info() { return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum()); } diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 57afb5cbb5..5d05243238 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(gemm_multi_d) add_subdirectory(data_type) +add_subdirectory(slice_tile) diff --git a/test/ck_tile/slice_tile/CMakeLists.txt b/test/ck_tile/slice_tile/CMakeLists.txt new file mode 100644 index 0000000000..d0d1a4ee00 --- /dev/null +++ b/test/ck_tile/slice_tile/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_slice_tile test_slice_tile.cpp) \ No newline at end of file diff --git a/test/ck_tile/slice_tile/test_slice_tile.cpp b/test/ck_tile/slice_tile/test_slice_tile.cpp new file mode 100644 index 0000000000..57770d3bf6 --- /dev/null +++ b/test/ck_tile/slice_tile/test_slice_tile.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/core.hpp" +#include + +// clang-format off +template, + typename SliceEnd_ = ck_tile::sequence<64, 16>, + typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0>> +void test_slice_distribution_from_x_case_0(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {}) +{ + // slice length [-1, 16] + using namespace ck_tile; + constexpr auto r = detail::slice_distribution_from_x( + make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence<2, 2, 1, 4, 4>>, + // Y P P Y P Y P Y + tuple, sequence<2, 1>>, + tuple, sequence<3, 2>>, + sequence<1, 2, 2, 2>, + sequence<0, 0, 2, 4>>{}), + SliceStart_{}, + SliceEnd_{}); + + using sliced_dist_enc = remove_cvref_t{}].get_static_tile_distribution_encoding())>; + using target_dist_enc = tile_distribution_encoding, + tuple, sequence<1, 2, 1, 4, 2>>, + // Y P P Y P Y P Y + tuple, sequence<2, 1>>, + tuple, sequence<3, 2>>, + sequence<1, 2, 2, 2>, + sequence<0, 0, 2, 4>>; + + static_assert(std::is_same_v); + + using sliced_y_origins = remove_cvref_t{}])>; + using sliced_y_lengths = remove_cvref_t{}])>; + static_assert(std::is_same_v); + static_assert(std::is_same_v>); +} + +template, + typename SliceEnd_ = ck_tile::sequence<16, 16>, + typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0, 0>> +void test_slice_distribution_from_x_case_1(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {}) +{ + // slice length [16, 16] + using namespace ck_tile; + constexpr auto r = detail::slice_distribution_from_x( + make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence<2, 4, 2, 8, 2>>, + // Y P Y Y P Y Y P + tuple, sequence<2, 2>>, + tuple, sequence<4, 1>>, + sequence<1, 1, 2, 2, 2>, + sequence<0, 2, 0, 2, 3>>{}), + SliceStart_{}, + SliceEnd_{}); + + using sliced_dist_enc = remove_cvref_t{}].get_static_tile_distribution_encoding())>; + using target_dist_enc = tile_distribution_encoding, + tuple, sequence<1, 4, 1, 2, 2>>, + // Y P Y Y P Y Y P + tuple, sequence<2, 2>>, + tuple, sequence<4, 1>>, + sequence<1, 1, 
2, 2, 2>, + sequence<0, 2, 0, 2, 3>>; + + static_assert(std::is_same_v); + + using sliced_y_origins = remove_cvref_t{}])>; + using sliced_y_lengths = remove_cvref_t{}])>; + static_assert(std::is_same_v); + static_assert(std::is_same_v>); +} + +template, + typename SliceEnd_ = ck_tile::sequence<12, 48>, + typename Y_Origin_ = ck_tile::sequence<0, 0, 0, 0, 0>> +void test_slice_distribution_from_x_case_2(SliceStart_ = {}, SliceEnd_={}, Y_Origin_ = {}) +{ + // slice length [12, 48] + using namespace ck_tile; + constexpr auto r = detail::slice_distribution_from_x( + make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence<2, 2, 1, 4, 3, 4>>, + // Y P Y Y P, Y, P P, Y + tuple, sequence<2, 2, 2>>, + tuple, sequence<4, 1, 3>>, + sequence<1, 2, 1, 2, 2>, + sequence<2, 0, 0, 5, 2>>{}), + SliceStart_{}, + SliceEnd_{}); + + using sliced_dist_enc = remove_cvref_t{}].get_static_tile_distribution_encoding())>; + using target_dist_enc = tile_distribution_encoding, + tuple, sequence<1, 2, 1, 4, 3, 2>>, + // Y P Y Y P, Y, P P, Y + tuple, sequence<2, 2, 2>>, + tuple, sequence<4, 1, 3>>, + sequence<1, 2, 1, 2, 2>, + sequence<2, 0, 0, 5, 2>>; + + static_assert(std::is_same_v); + + using sliced_y_origins = remove_cvref_t{}])>; + using sliced_y_lengths = remove_cvref_t{}])>; + static_assert(std::is_same_v); + static_assert(std::is_same_v>); +} + +void test_slice_distribution_from_x() +{ + using namespace ck_tile; + + test_slice_distribution_from_x_case_0(sequence< 0, 0>{}, sequence<-1, 16>{}, sequence<0, 0, 0, 0>{}); + test_slice_distribution_from_x_case_0(sequence< 0, 16>{}, sequence<-1, 32>{}, sequence<0, 0, 0, 2>{}); + test_slice_distribution_from_x_case_0(sequence< 0, 32>{}, sequence<-1, 48>{}, sequence<0, 1, 0, 0>{}); + test_slice_distribution_from_x_case_0(sequence< 0, 48>{}, sequence<-1, 64>{}, sequence<0, 1, 0, 2>{}); + + test_slice_distribution_from_x_case_1(sequence< 0, 0>{}, sequence<16, 16>{}, sequence<0, 0, 0, 0, 0>{}); + 
test_slice_distribution_from_x_case_1(sequence<16, 16>{}, sequence<32, 32>{}, sequence<1, 0, 0, 0, 2>{}); + test_slice_distribution_from_x_case_1(sequence<32, 64>{}, sequence<48, 80>{}, sequence<2, 0, 0, 1, 0>{}); + test_slice_distribution_from_x_case_1(sequence<48, 208>{}, sequence<64, 224>{}, sequence<3, 0, 1, 1, 2>{}); + + test_slice_distribution_from_x_case_2(sequence< 0, 0>{}, sequence<12, 48>{}, sequence<0, 0, 0, 0, 0>{}); + test_slice_distribution_from_x_case_2(sequence<12, 144>{}, sequence<24, 192>{}, sequence<0, 1, 2, 2, 0>{}); +} + +// clang-format on +int main() { test_slice_distribution_from_x(); }