[CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905)

* Allow sharing partition index across threads

* Fix typo PartitoinIndex -> PartitionIndex

* Remove C++20 'requires' usages

* Add missing template arguments

* Fix load_tile() overload ambiguity issue

* Use SFINAE to exclude invalid arguments

* Add additional offset parameter to the async_load_tile()

* Remove async_load_tile() default argument to avoid ambiguity

* Extract tile_window coordinate compute logic as method

* Use warp-shared LDS base address in tile_window::async_load()

* Add constraint to tile_window::load() templates

* Fix wrong type traits is_class_v<> usages

* Add missing constraint to async_load_tile()

* Add missing tile_window::load() overload

* Add more constraint to avoid load_tile() call ambiguity

* Rename ParitionIndex as ReplacementPartitionIndex

* Update pre_computed_warp_coords_ in move_extended()

* Fix inconsistency between template parameters and documentation

* Allow specifying pre-computed parition index

* Add type straits is_sequence<> & is_tile_distribution<>

* Add type straits is_tensor_view<>

* Add type constraints to make_tile_window() templates

* Allow passing partition_index to set_tile_if()

* Allow specifying partition_index to store_tile()

* Add missing template parameter of replace_bottom_tensor_view()

* Allow passing partition_index to Default2DEpilogue

* Make get_partition_index() public

* Add _with_offset() postfix to avoid resolution error

* Remove ReplacementPartitionIndex template param

* Add missing comments

* Add load_tile_transpose_with_offset() overload
This commit is contained in:
Po Yen Chen
2025-11-12 10:26:14 +08:00
committed by GitHub
parent 92c1f4981a
commit 40d2ed0f2a
11 changed files with 441 additions and 58 deletions

View File

@@ -32,7 +32,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
constexpr index_t idim_p_lane = NDimP - 1;
const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution());
const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();