From 98033a68ce3bf9da8e7db8bff5ee23ffb0aa1274 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 12 Nov 2025 02:42:51 +0000 Subject: [PATCH] Merge commit '40d2ed0f2a442026c57dc17e6e7bd281b6c2535c' into develop --- include/ck_tile/core/container/sequence.hpp | 11 + include/ck_tile/core/tensor/load_tile.hpp | 51 ++++- .../core/tensor/load_tile_transpose.hpp | 60 +++++- .../core/tensor/static_distributed_tensor.hpp | 41 +++- include/ck_tile/core/tensor/store_tile.hpp | 51 +++++ include/ck_tile/core/tensor/tensor_view.hpp | 15 ++ .../ck_tile/core/tensor/tile_distribution.hpp | 27 ++- .../core/tensor/tile_scatter_gather.hpp | 5 +- include/ck_tile/core/tensor/tile_window.hpp | 195 +++++++++++++++--- .../ops/epilogue/default_2d_epilogue.hpp | 41 +++- .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- 11 files changed, 441 insertions(+), 58 deletions(-) diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index cfec2237f9..1a88a98cbf 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -214,6 +214,17 @@ CK_TILE_HOST_DEVICE static void print(const sequence&) printf(">"); } +template +struct is_sequence : std::false_type +{ +}; +template +struct is_sequence> : std::true_type +{ +}; +template +inline constexpr bool is_sequence_v = is_sequence::value; + namespace impl { template struct __integer_sequence; diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 2e9ab0f5c6..1be4259e97 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -17,6 +17,19 @@ #include "ck_tile/core/tensor/null_tensor.hpp" namespace ck_tile { +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template >> +CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, number{}, bool_constant{}); +} template CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, @@ -49,6 +62,23 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, tile_window, elementwise, number{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, dst_tile, number{}, bool_constant{}); +} + template {}, bool_constant{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.async_load_with_offset( + offset, lds_tile, number{}, bool_constant{}); +} + template = {}, bool_constant = {}) { - return tile_window.async_load( - lds_tile, number{}, bool_constant{}); + return async_load_tile_with_offset( + lds_tile, tile_window, 0, number{}, bool_constant{}); } template ::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto -load_tile_transpose(const tile_window_with_static_distribution& tile_window) +CK_TILE_DEVICE auto load_tile_transpose_with_offset( + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) { using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose(); + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); @@ -443,4 +446,49 @@ load_tile_transpose(const tile_window_with_static_distribution, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE auto +load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) +{ + return load_tile_transpose_with_offset(tile_window, 0); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index b73a27c8d5..5228ad978a 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -155,11 +155,11 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi // get X indices from tuple of tile_distributed_index<> template -CK_TILE_HOST_DEVICE constexpr auto -get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, - DistributedIndices distributed_indices) +CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices( + StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices, + decltype(get_partition_index(tile_distribution)) partition_index) { - const auto partition_index = detail::get_partition_index(tile_distribution); constexpr auto y_indices = tile_distribution.get_y_indices_from_distributed_indices(distributed_indices); @@ -170,6 +170,16 @@ get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, return x_coord.get_bottom_index(); } +// get X indices from tuple of tile_distributed_index<> +template +CK_TILE_HOST_DEVICE constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + return get_x_indices_from_distributed_indices( + tile_distribution, distributed_indices, get_partition_index(tile_distribution)); +} + template CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor& out_tensor, @@ -192,6 +202,29 @@ set_tile_if(static_distributed_tensor& out_ten }); } +template +CK_TILE_HOST_DEVICE void +set_tile_if(static_distributed_tensor& out_tensor, + DataType value, + XIndicesPredicate predicate, + decltype(get_partition_index(std::declval())) partition_index) +{ + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + StaticTileDistribution{}, distributed_indices, partition_index); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + // this function used inside span loop over template CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number) diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d..b535b40534 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -38,6 +39,31 @@ store_tile(tile_window_with_static_lengths& t tile_window.store(dstr_tensor); } +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store(dstr_tensor); +} + template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store_raw(dstr_tensor); +} + template +struct is_tensor_view : std::false_type +{ +}; +template +struct is_tensor_view> : std::true_type +{ +}; +template <> +struct is_tensor_view : std::true_type +{ +}; +template +inline constexpr bool is_tensor_view_v = is_tensor_view::value; + template CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) { - return Distribution::_get_partition_index(); + return Distribution::get_partition_index(); } -} // namespace detail // distributed span template @@ -91,7 +89,7 @@ struct tile_distribution CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } - CK_TILE_HOST_DEVICE static auto _get_partition_index() + CK_TILE_HOST_DEVICE static auto get_partition_index() { // only support warp-tile and block-tile static_assert(NDimP == 1 or NDimP == 2, "wrong!"); @@ -172,9 +170,9 @@ struct tile_distribution } #endif - template + template CK_TILE_HOST_DEVICE auto - calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const + calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const { const auto ps_ys_idx = container_concat(ps_idx, array{0}); const auto window_adaptor_thread_coord_tmp = @@ -230,6 +228,23 @@ struct tile_distribution } }; +template +struct is_tile_distribution : std::false_type +{ +}; +template +struct is_tile_distribution> : std::true_type +{ +}; +template +inline constexpr bool is_tile_distribution_v = is_tile_distribution::value; + namespace detail { template diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 4b04fd513d..e77ca805bb 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -189,8 +189,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(get_partition_index(tile_distribution), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = @@ -836,7 +835,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); + container_concat(get_partition_index(tile_dstr_), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index cfa2420f2f..1123ce7604 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" @@ -67,18 +68,54 @@ struct tile_window_with_static_distribution const typename Base::BottomTensorView& bottom_tensor_view, const typename Base::WindowLengths& window_lengths, const typename Base::BottomTensorIndex& window_origin, - const typename Base::TileDstr& tile_distribution) + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) : pre_computed_coords_{} { - this->window_origin_ = window_origin; - this->window_lengths_ = window_lengths; - this->bottom_tensor_view_ = bottom_tensor_view; - this->tile_dstr_ = tile_distribution; + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; + + pre_computed_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index); + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + auto use_lane_id_0 = partition_index; + use_lane_id_0[1] = 0; + + pre_computed_warp_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0); + } + } + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution( + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : tile_window_with_static_distribution(bottom_tensor_view, + window_lengths, + window_origin, + tile_distribution, + get_partition_index(tile_distribution)) + { + } + + CK_TILE_DEVICE constexpr auto + prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) const + { + array, NumCoord> + coords; + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(partition_index, multi_index{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -105,18 +142,31 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - pre_computed_coords_(iCoord) = - make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + + return coords; } template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + return load_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, number{}, bool_constant{}); + load_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -236,6 +286,19 @@ struct tile_window_with_static_distribution CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const + { + load_with_offset( + 0, dst_tensor, number{}, bool_constant{}); + } + + template >>> + CK_TILE_DEVICE auto load_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -258,7 +321,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, 0, bool_constant{}); + bottom_tensor_thread_coord, offset, bool_constant{}); // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -450,10 +513,12 @@ struct tile_window_with_static_distribution template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - number = {}, - bool_constant = {}) const + bool oob_conditional_check = true, + typename = std::enable_if_t>>> + CK_TILE_DEVICE auto async_load_with_offset(index_t offset, + LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; @@ -472,12 +537,15 @@ struct tile_window_with_static_distribution auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; // Use precomputed window origin auto lds_bottom_tensor_thread_idx = - window_origin + window_adaptor_thread_coord.get_bottom_index(); + window_origin + window_adaptor_warp_coord.get_bottom_index(); // Use precomputed tensor descriptor const auto lds_coord = @@ -490,7 +558,7 @@ struct tile_window_with_static_distribution this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, - number<0>{}, + offset, bool_constant{}); // Move thread coordinate if not last access @@ -503,18 +571,33 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); } }); }); } template - CK_TILE_DEVICE auto load_transpose() const + CK_TILE_DEVICE auto load_transpose(number = {}, + bool_constant = {}) const + { + return this->template load_transpose_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - this->template load_transpose( - dst_tensor, number{}, bool_constant{}); + this->template load_transpose_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -522,9 +605,10 @@ struct tile_window_with_static_distribution typename DistributedTensor, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true> - CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -550,7 +634,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = this->get_bottom_tensor_view() .template get_transpose_vectorized_elements( - bottom_tensor_thread_coord, 0); + bottom_tensor_thread_coord, offset); // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( @@ -862,16 +946,26 @@ struct tile_window_with_static_distribution pre_computed_coords_(iCoord)(I1), step); }); + + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step); + }); + } } CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { // TODO: this use less register for FA, but more register for GEMM // need investigation - const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - this->tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(this->tile_dstr_), - array{0})); + const auto window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(get_partition_index(this->tile_dstr_), + array{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -908,6 +1002,12 @@ struct tile_window_with_static_distribution // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + // pre_computed_warp_coords_ exists only in the global memory tile_window + std::conditional_t< + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global, + array, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy @@ -929,6 +1029,27 @@ make_tile_window(const TensorView_& tensor_view, tensor_view, window_lengths, origin, tile_distribution}; } +template && + is_tile_distribution_v>> +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index, + number = {}) +{ + return tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, partition_index}; +} + // this version can't be called in a constexpr context template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) +{ + return make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + partition_index); +} + template CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window.get_bottom_tensor_view(), - tile_window.get_window_lengths(), - tile_window.get_window_origin(), - tile_distribution); + auto w = make_tile_window(tile_window, tile_distribution); w.init_raw(); return w; } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 2843966cd7..8cf47c46e7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -93,13 +93,27 @@ struct Default2DEpilogue const DsDramWindows& ds_dram_windows, void* = nullptr) const { + constexpr bool is_partition_index = + std::is_convertible_v; + const auto storeOrUpdateTile = [&](const auto& o_tile) { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile_raw(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } } else { @@ -111,16 +125,35 @@ struct Default2DEpilogue { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } else { - update_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + update_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } } }; - if constexpr(!std::is_same_v && Problem::NumDTensor >= 1) + if constexpr(!std::is_same_v && !is_partition_index && + Problem::NumDTensor >= 1) { using elementwise_result_t = decltype(load_tile( make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 7a10d1fa56..2fd8a48eee 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -32,7 +32,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, constexpr index_t idim_p_lane = NDimP - 1; - const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution()); + const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution()); const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();