From cb33803f3a790e3155df239918df237319c8517c Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:19:29 +0800 Subject: [PATCH] [CK_TILE] add scatter_gather (#1609) [ROCm/composable_kernel commit: 4d7e063a0a2dfb183bc3876b1ff021829aabd38b] --- include/ck_tile/core.hpp | 1 + .../core/algorithm/coordinate_transform.hpp | 104 +++++++ .../core/algorithm/indexing_adaptor.hpp | 60 ++++ test/CMakeLists.txt | 1 + test/scatter_gather/CMakeLists.txt | 2 + test/scatter_gather/scatter_gather.cpp | 276 ++++++++++++++++++ 6 files changed, 444 insertions(+) create mode 100644 include/ck_tile/core/algorithm/indexing_adaptor.hpp create mode 100644 test/scatter_gather/CMakeLists.txt create mode 100644 test/scatter_gather/scatter_gather.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 56dfbd636b..14991d375a 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core/algorithm/cluster_descriptor.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/arch.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 5c7e489804..aaa7db2574 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -23,6 +23,7 @@ enum struct coord_transform_enum replicate, xor_t, offset, + indexing, }; template @@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1> } }; +template +struct indexing : public base_transform<1, 1> +{ + static constexpr index_t NDimUp = 1; + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + UpLengths up_lengths_; + IndexingAdaptor iadaptor_; + + CK_TILE_HOST_DEVICE constexpr indexing() = default; + + CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length, + const IndexingAdaptor& iadaptor) + : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::indexing; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + iadaptor_.calculate_lower_index(idx_low, idx_up); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up) const + { + // TODO: nonthing changed here + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp && + LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + + iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up); + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + IndexingAdaptor::is_known_at_compile_time(); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("embed{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + printf("}"); + } +}; + //******************************************************************************************************* template @@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le } } // namespace ck_tile + +#include "ck_tile/core/algorithm/indexing_adaptor.hpp" +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths, + const Indices& indices) +{ + // by default we use the simplest one + return indexing>>{ + up_lengths, indexing_adaptor_onshot_cached>{indices}}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor) +{ + return indexing{up_lengths, iadaptor}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/indexing_adaptor.hpp b/include/ck_tile/core/algorithm/indexing_adaptor.hpp new file mode 100644 index 0000000000..ef59abdc99 --- /dev/null +++ b/include/ck_tile/core/algorithm/indexing_adaptor.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { +// pre-defined indexing adaptor used for indexing(scatter/gather) + +// this version cache the index inside thread register(which is also prefered in real senario) +// however it's user's responsibility that each thread only provide one indexing, which means +// move coordinate will not change on this dim +template +struct indexing_adaptor_onshot_cached +{ + + CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default; + CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx) + : cached_idx_(idx) + { + } + IndexingType cached_idx_; + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& /*idx_up*/) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = cached_idx_; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& /*idx_low*/, + const UpIdx& /*idx_up*/) const + { + // TODO: nonthing changed here + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}]; + + // pass the diff to lower, but not changing the actually index + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } +}; +} // namespace ck_tile diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b836dd687e..b12ced5244 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL add_subdirectory(smfmac_op) endif() add_subdirectory(position_embedding) +add_subdirectory(scatter_gather) diff --git a/test/scatter_gather/CMakeLists.txt b/test/scatter_gather/CMakeLists.txt new file mode 100644 index 0000000000..cc327d42db --- /dev/null +++ b/test/scatter_gather/CMakeLists.txt @@ -0,0 +1,2 @@ +add_test_executable(test_scatter_gather scatter_gather.cpp) +# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker) diff --git a/test/scatter_gather/scatter_gather.cpp b/test/scatter_gather/scatter_gather.cpp new file mode 100644 index 0000000000..439e792dd8 --- /dev/null +++ b/test/scatter_gather/scatter_gather.cpp @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" + +#ifndef TEST_SCATTER_GATHER_VERBOSE +#define TEST_SCATTER_GATHER_VERBOSE 1 +#endif + +#define HIP_CALL(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + printf("[hiperror](%d) fail to call %s", static_cast(err), #call); \ + exit(0); \ + } \ + } while(0) + +/* +TODO: +This is a simple design of scatter/gather through indexing transform, with limitations +We may design a scatter/gather adaptor layer directly inside tile window +*/ +template +__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr, + const INDEX_BUF_TYPE* dst_row_idx_ptr, + const DATA_TYPE* src_ptr, + DATA_TYPE* dst_ptr, + ck_tile::index_t n_row_total, + ck_tile::index_t /*n_row_select*/, + ck_tile::index_t n_cols) +{ + using namespace ck_tile; + + // some constexpr vars + constexpr index_t vec = ALIGNMENT; + static_assert(COL_TILE_SIZE % vec == 0); + constexpr index_t col_lanes = COL_TILE_SIZE / vec; + constexpr index_t warp_size = ck_tile::get_warp_size(); + static_assert(warp_size % col_lanes == 0); + constexpr index_t row_lanes = warp_size / col_lanes; + constexpr index_t num_warps = BLOCK_SIZE / warp_size; + static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0); + constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes); + static_assert( + row_repeat == 1, + "currently indexing not support(and would be not performant) if row_repeat has more"); + + // tile partitioner + index_t tile_col_idx = 0; + index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE; + + // create our tild distribution, which tell us the location of different threads + constexpr auto src_dist = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + const auto coord = src_dist.calculate_index(); + const auto row_coord = coord[number<0>{}] + tile_row_idx; + + // load the current row index from the indexing buffer. we do not use ck_tile utility here + INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord]; + INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord]; + + // printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast(threadIdx.x), + // static_cast(src_row_id), static_cast(dst_row_id)); + + const auto src_view = + make_naive_tensor_view(src_ptr, + make_tuple(n_row_total, n_cols), + make_tuple(n_cols, 1), + number{}, // alignement + number<1>{}); + + const auto src_gather_view = transform_tensor_view( + src_view, + make_tuple(make_indexing_transform( + n_row_total, + src_row_id), // here we replace row_idx which is loaded from another buffer + make_pass_through_transform(n_cols)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto src_tile = make_tile_window(src_gather_view, + make_tuple(number{}, number{}), + {tile_row_idx, tile_col_idx}, + src_dist); + + const auto dst_view = + make_naive_tensor_view(dst_ptr, + make_tuple(n_row_total, n_cols), + make_tuple(n_cols, 1), + number{}, + number<1>{}); + + const auto dst_scatter_view = transform_tensor_view( + dst_view, + make_tuple(make_indexing_transform( + n_row_total, + dst_row_id), // here we replace row_idx which is loaded from another buffer + make_pass_through_transform(n_cols)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto dst_tile = make_tile_window(dst_scatter_view, + make_tuple(number{}, number{}), + {tile_row_idx, tile_col_idx}, + src_dist /*reuse distribution*/); + + // we finished descriptor construction and index calculation, now start load/store + for(auto i = 0; i < n_cols; i += COL_TILE_SIZE) + { + // note that scatter/gather are just the same API when doing load store as normal memory + // operation + auto data = load_tile(src_tile); + store_tile(dst_tile, data); + + move_tile_window(src_tile, {number<0>{}, number{}}); + move_tile_window(dst_tile, {number<0>{}, number{}}); + } +} + +union pixel +{ + struct __attribute__((packed)) + { + unsigned int r : 6; + unsigned int c : 10; + }; + ushort data; +}; + +struct unique_linear_rand +{ + unique_linear_rand(int capacity_) : capacity(capacity_) {} + std::unordered_set set; + int gen() + { + if(static_cast(set.size()) >= capacity) + { + printf("overflow, but will give you an number as well\n"); + return std::rand() % capacity; + } + while(1) + { + int r = std::rand() % capacity; + if(set.count(r) == 1) + { + continue; + } + set.insert(r); + return r; + } + } + + int capacity; +}; + +int main() +{ + int row_total = 64; + int row_select = 8 * 2; + int col = 256 * 2; + using fp16_t = ck_tile::fp16_t; + + constexpr int row_tile = 8; + constexpr int col_tile = 256; + + fp16_t* src = reinterpret_cast(malloc(row_total * col * sizeof(fp16_t))); + for(int i_r = 0; i_r < row_total; i_r++) + { + for(int i_c = 0; i_c < col; i_c++) + { + int i = i_r * col + i_c; + pixel p; + p.r = i_r; + p.c = i_c; + ushort d = p.data; + src[i] = ck_tile::bit_cast(d); // for simplicity, just cast + } + } + + fp16_t* dst = reinterpret_cast(malloc(row_total * col * sizeof(fp16_t))); + int* src_idx = reinterpret_cast(malloc(row_select * sizeof(int))); + int* dst_idx = reinterpret_cast(malloc(row_select * sizeof(int))); + // std::srand(std::time(std::nullptr)); + // std::srand(11935); + std::srand(std::time(nullptr)); + auto src_gen = unique_linear_rand(row_total); + auto dst_gen = unique_linear_rand(row_total); // dst index must be unique. src is fine + for(int i_r = 0; i_r < row_select; i_r++) + { + src_idx[i_r] = src_gen.gen(); + dst_idx[i_r] = dst_gen.gen(); + } + + void* dev_src; + void* dev_dst; + void* dev_src_idx; + void* dev_dst_idx; + HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t))); + HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t))); + HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int))); + HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int))); + + HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice)); + + constexpr int bdim = 256; + int gdim = (row_select + row_tile - 1) / row_tile; + row_scatter_gather<<>>(reinterpret_cast(dev_src_idx), + reinterpret_cast(dev_dst_idx), + reinterpret_cast(dev_src), + reinterpret_cast(dev_dst), + row_total, + row_select, + col); + + HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost)); + +#if TEST_SCATTER_GATHER_VERBOSE + printf("select row:"); + for(int i_r = 0; i_r < row_select; i_r++) + { + printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]); + } + printf("\n"); +#endif + + int err_cnt = 0; + for(int i_r = 0; i_r < row_select; i_r++) + { + for(int i_c = 0; i_c < col; i_c++) + { + int i = dst_idx[i_r] * col + i_c; + pixel p = ck_tile::bit_cast(dst[i]); + bool is_ok = p.r == src_idx[i_r] && p.c == i_c; + if(!is_ok) + { + if(i_c == 0) + printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]); + err_cnt++; + } + } + } +#if TEST_SCATTER_GATHER_VERBOSE + printf("err:%d\n", err_cnt); +#endif + + free(src); + free(dst); + free(src_idx); + free(dst_idx); + return err_cnt == 0 ? 0 : -1; +}