mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905)
* Allow sharing partition index across threads * Fix typo PartitoinIndex -> PartitionIndex * Remove C++20 'requires' usages * Add missing template arguments * Fix load_tile() overload ambiguity issue * Use SFINAE to exclude invalid arguments * Add additional offset parameter to the async_load_tile() * Remove async_load_tile() default argument to avoid ambiguity * Extract tile_window coordinate compute logic as method * Use warp-shared LDS base address in tile_window::async_load() * Add constraint to tile_window::load() templates * Fix wrong type traits is_class_v<> usages * Add missing constraint to async_load_tile() * Add missing tile_window::load() overload * Add more constraint to avoid load_tile() call ambiguity * Rename ParitionIndex as ReplacementPartitionIndex * Update pre_computed_warp_coords_ in move_extended() * Fix inconsistency between template parameters and documentation * Allow specifying pre-computed parition index * Add type straits is_sequence<> & is_tile_distribution<> * Add type straits is_tensor_view<> * Add type constraints to make_tile_window() templates * Allow passing partition_index to set_tile_if() * Allow specifying partition_index to store_tile() * Add missing template parameter of replace_bottom_tensor_view() * Allow passing partition_index to Default2DEpilogue * Make get_partition_index() public * Add _with_offset() postfix to avoid resolution error * Remove ReplacementPartitionIndex template param * Add missing comments * Add load_tile_transpose_with_offset() overload
This commit is contained in:
@@ -93,13 +93,27 @@ struct Default2DEpilogue
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* = nullptr) const
|
||||
{
|
||||
constexpr bool is_partition_index =
|
||||
std::is_convertible_v<decltype(ds_dram_windows),
|
||||
decltype(get_partition_index(
|
||||
o_acc_tile.get_tile_distribution()))>;
|
||||
|
||||
const auto storeOrUpdateTile = [&](const auto& o_tile) {
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
if constexpr(is_partition_index)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp,
|
||||
cast_tile<ODataType>(o_tile),
|
||||
/*partition_index=*/ds_dram_windows);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -111,16 +125,35 @@ struct Default2DEpilogue
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
if constexpr(is_partition_index)
|
||||
{
|
||||
store_tile(o_dram_window_tmp,
|
||||
cast_tile<ODataType>(o_tile),
|
||||
/*partition_index=*/ds_dram_windows);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
if constexpr(is_partition_index)
|
||||
{
|
||||
update_tile(o_dram_window_tmp,
|
||||
cast_tile<ODataType>(o_tile),
|
||||
/*partition_index=*/ds_dram_windows);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && Problem::NumDTensor >= 1)
|
||||
if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && !is_partition_index &&
|
||||
Problem::NumDTensor >= 1)
|
||||
{
|
||||
using elementwise_result_t = decltype(load_tile(
|
||||
make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
|
||||
|
||||
@@ -32,7 +32,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
|
||||
const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution());
|
||||
const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
Reference in New Issue
Block a user