mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] add scatter_gather (#1609)
[ROCm/composable_kernel commit: 4d7e063a0a]
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
|
||||
replicate,
|
||||
xor_t,
|
||||
offset,
|
||||
indexing,
|
||||
};
|
||||
|
||||
template <index_t NDimLow, index_t NDimUp>
|
||||
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
static constexpr index_t NDimUp = 1;
|
||||
|
||||
using LowerIndex = multi_index<1>;
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths = decltype(make_tuple(UpLength{}));
|
||||
UpLengths up_lengths_;
|
||||
IndexingAdaptor iadaptor_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
|
||||
const IndexingAdaptor& iadaptor)
|
||||
: up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
|
||||
{
|
||||
return coord_transform_enum::indexing;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
iadaptor_.calculate_lower_index(idx_low, idx_up);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
|
||||
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_always_mapped_to_valid_lower_index()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename UpLength, typename Indices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
|
||||
const Indices& indices)
|
||||
{
|
||||
// by default we use the simplest one
|
||||
return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
|
||||
up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
|
||||
{
|
||||
return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// pre-defined indexing adaptor used for indexing(scatter/gather)
|
||||
|
||||
// this version cache the index inside thread register(which is also prefered in real senario)
|
||||
// however it's user's responsibility that each thread only provide one indexing, which means
|
||||
// move coordinate will not change on this dim
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor_onshot_cached
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
|
||||
: cached_idx_(idx)
|
||||
{
|
||||
}
|
||||
IndexingType cached_idx_;
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = cached_idx_;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
|
||||
add_subdirectory(smfmac_op)
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
add_subdirectory(scatter_gather)
|
||||
|
||||
2
test/scatter_gather/CMakeLists.txt
Normal file
2
test/scatter_gather/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_test_executable(test_scatter_gather scatter_gather.cpp)
|
||||
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
|
||||
276
test/scatter_gather/scatter_gather.cpp
Normal file
276
test/scatter_gather/scatter_gather.cpp
Normal file
@@ -0,0 +1,276 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#ifndef TEST_SCATTER_GATHER_VERBOSE
|
||||
#define TEST_SCATTER_GATHER_VERBOSE 1
|
||||
#endif
|
||||
|
||||
#define HIP_CALL(call) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t err = call; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
|
||||
exit(0); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
/*
|
||||
TODO:
|
||||
This is a simple design of scatter/gather through indexing transform, with limitations
|
||||
We may design a scatter/gather adaptor layer directly inside tile window
|
||||
*/
|
||||
template <ck_tile::index_t ROW_TILE_SIZE = 8,
|
||||
ck_tile::index_t COL_TILE_SIZE = 32 * 8,
|
||||
ck_tile::index_t BLOCK_SIZE = 256,
|
||||
ck_tile::index_t ALIGNMENT = 8,
|
||||
typename INDEX_BUF_TYPE = ck_tile::index_t,
|
||||
typename DATA_TYPE = ck_tile::fp16_t>
|
||||
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
|
||||
const INDEX_BUF_TYPE* dst_row_idx_ptr,
|
||||
const DATA_TYPE* src_ptr,
|
||||
DATA_TYPE* dst_ptr,
|
||||
ck_tile::index_t n_row_total,
|
||||
ck_tile::index_t /*n_row_select*/,
|
||||
ck_tile::index_t n_cols)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// some constexpr vars
|
||||
constexpr index_t vec = ALIGNMENT;
|
||||
static_assert(COL_TILE_SIZE % vec == 0);
|
||||
constexpr index_t col_lanes = COL_TILE_SIZE / vec;
|
||||
constexpr index_t warp_size = ck_tile::get_warp_size();
|
||||
static_assert(warp_size % col_lanes == 0);
|
||||
constexpr index_t row_lanes = warp_size / col_lanes;
|
||||
constexpr index_t num_warps = BLOCK_SIZE / warp_size;
|
||||
static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
|
||||
constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
|
||||
static_assert(
|
||||
row_repeat == 1,
|
||||
"currently indexing not support(and would be not performant) if row_repeat has more");
|
||||
|
||||
// tile partitioner
|
||||
index_t tile_col_idx = 0;
|
||||
index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;
|
||||
|
||||
// create our tild distribution, which tell us the location of different threads
|
||||
constexpr auto src_dist = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
const auto coord = src_dist.calculate_index();
|
||||
const auto row_coord = coord[number<0>{}] + tile_row_idx;
|
||||
|
||||
// load the current row index from the indexing buffer. we do not use ck_tile utility here
|
||||
INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
|
||||
INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
|
||||
|
||||
// printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
|
||||
// static_cast<int>(src_row_id), static_cast<int>(dst_row_id));
|
||||
|
||||
const auto src_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(src_ptr,
|
||||
make_tuple(n_row_total, n_cols),
|
||||
make_tuple(n_cols, 1),
|
||||
number<vec>{}, // alignement
|
||||
number<1>{});
|
||||
|
||||
const auto src_gather_view = transform_tensor_view(
|
||||
src_view,
|
||||
make_tuple(make_indexing_transform(
|
||||
n_row_total,
|
||||
src_row_id), // here we replace row_idx which is loaded from another buffer
|
||||
make_pass_through_transform(n_cols)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto src_tile = make_tile_window(src_gather_view,
|
||||
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
|
||||
{tile_row_idx, tile_col_idx},
|
||||
src_dist);
|
||||
|
||||
const auto dst_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(dst_ptr,
|
||||
make_tuple(n_row_total, n_cols),
|
||||
make_tuple(n_cols, 1),
|
||||
number<vec>{},
|
||||
number<1>{});
|
||||
|
||||
const auto dst_scatter_view = transform_tensor_view(
|
||||
dst_view,
|
||||
make_tuple(make_indexing_transform(
|
||||
n_row_total,
|
||||
dst_row_id), // here we replace row_idx which is loaded from another buffer
|
||||
make_pass_through_transform(n_cols)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto dst_tile = make_tile_window(dst_scatter_view,
|
||||
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
|
||||
{tile_row_idx, tile_col_idx},
|
||||
src_dist /*reuse distribution*/);
|
||||
|
||||
// we finished descriptor construction and index calculation, now start load/store
|
||||
for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
|
||||
{
|
||||
// note that scatter/gather are just the same API when doing load store as normal memory
|
||||
// operation
|
||||
auto data = load_tile(src_tile);
|
||||
store_tile(dst_tile, data);
|
||||
|
||||
move_tile_window(src_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
|
||||
move_tile_window(dst_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
|
||||
}
|
||||
}
|
||||
|
||||
union pixel
|
||||
{
|
||||
struct __attribute__((packed))
|
||||
{
|
||||
unsigned int r : 6;
|
||||
unsigned int c : 10;
|
||||
};
|
||||
ushort data;
|
||||
};
|
||||
|
||||
struct unique_linear_rand
|
||||
{
|
||||
unique_linear_rand(int capacity_) : capacity(capacity_) {}
|
||||
std::unordered_set<int> set;
|
||||
int gen()
|
||||
{
|
||||
if(static_cast<int>(set.size()) >= capacity)
|
||||
{
|
||||
printf("overflow, but will give you an number as well\n");
|
||||
return std::rand() % capacity;
|
||||
}
|
||||
while(1)
|
||||
{
|
||||
int r = std::rand() % capacity;
|
||||
if(set.count(r) == 1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
set.insert(r);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
int capacity;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
int row_total = 64;
|
||||
int row_select = 8 * 2;
|
||||
int col = 256 * 2;
|
||||
using fp16_t = ck_tile::fp16_t;
|
||||
|
||||
constexpr int row_tile = 8;
|
||||
constexpr int col_tile = 256;
|
||||
|
||||
fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
|
||||
for(int i_r = 0; i_r < row_total; i_r++)
|
||||
{
|
||||
for(int i_c = 0; i_c < col; i_c++)
|
||||
{
|
||||
int i = i_r * col + i_c;
|
||||
pixel p;
|
||||
p.r = i_r;
|
||||
p.c = i_c;
|
||||
ushort d = p.data;
|
||||
src[i] = ck_tile::bit_cast<fp16_t>(d); // for simplicity, just cast
|
||||
}
|
||||
}
|
||||
|
||||
fp16_t* dst = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
|
||||
int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
|
||||
int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
|
||||
// std::srand(std::time(std::nullptr));
|
||||
// std::srand(11935);
|
||||
std::srand(std::time(nullptr));
|
||||
auto src_gen = unique_linear_rand(row_total);
|
||||
auto dst_gen = unique_linear_rand(row_total); // dst index must be unique. src is fine
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
src_idx[i_r] = src_gen.gen();
|
||||
dst_idx[i_r] = dst_gen.gen();
|
||||
}
|
||||
|
||||
void* dev_src;
|
||||
void* dev_dst;
|
||||
void* dev_src_idx;
|
||||
void* dev_dst_idx;
|
||||
HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t)));
|
||||
HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t)));
|
||||
HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int)));
|
||||
HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int)));
|
||||
|
||||
HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice));
|
||||
HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
|
||||
HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
|
||||
|
||||
constexpr int bdim = 256;
|
||||
int gdim = (row_select + row_tile - 1) / row_tile;
|
||||
row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx),
|
||||
reinterpret_cast<int*>(dev_dst_idx),
|
||||
reinterpret_cast<fp16_t*>(dev_src),
|
||||
reinterpret_cast<fp16_t*>(dev_dst),
|
||||
row_total,
|
||||
row_select,
|
||||
col);
|
||||
|
||||
HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost));
|
||||
|
||||
#if TEST_SCATTER_GATHER_VERBOSE
|
||||
printf("select row:");
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
|
||||
int err_cnt = 0;
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
for(int i_c = 0; i_c < col; i_c++)
|
||||
{
|
||||
int i = dst_idx[i_r] * col + i_c;
|
||||
pixel p = ck_tile::bit_cast<pixel>(dst[i]);
|
||||
bool is_ok = p.r == src_idx[i_r] && p.c == i_c;
|
||||
if(!is_ok)
|
||||
{
|
||||
if(i_c == 0)
|
||||
printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]);
|
||||
err_cnt++;
|
||||
}
|
||||
}
|
||||
}
|
||||
#if TEST_SCATTER_GATHER_VERBOSE
|
||||
printf("err:%d\n", err_cnt);
|
||||
#endif
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
free(src_idx);
|
||||
free(dst_idx);
|
||||
return err_cnt == 0 ? 0 : -1;
|
||||
}
|
||||
Reference in New Issue
Block a user