mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Merge branch 'develop' into ck_tile/rmsnorm
This commit is contained in:
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
|
||||
|
||||
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
|
||||
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
|
||||
endif()
|
||||
if(PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
|
||||
endif()
|
||||
# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)
|
||||
46
example/ck_tile/06_permute/README.md
Normal file
46
example/ck_tile/06_permute/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# permute
|
||||
|
||||
This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example.
|
||||
|
||||
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-prec data type. fp16/bf16/fp32 (default:fp16)
|
||||
-shape the shape of the input tensor (default:2,3,4)
|
||||
-perm permute perm (default:2,1,0)
|
||||
```
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_permute -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_permute`
|
||||
|
||||
|
||||
## some examples
|
||||
```
|
||||
# torch
|
||||
x=torch.randn(2,3,4,6)
|
||||
y=x.permute(0,3,2,1).contiguous()
|
||||
|
||||
# ck_tile
|
||||
./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1
|
||||
```
|
||||
|
||||
or you can try the smoke_test
|
||||
```
|
||||
# in the root of ck_tile, after you build this example
|
||||
sh example/ck_tile/06_permute/script/smoke_test.sh
|
||||
```
|
||||
|
||||
### alternative implementation
|
||||
we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail.
|
||||
```
|
||||
# example
|
||||
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
|
||||
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
|
||||
```
|
||||
@@ -0,0 +1,98 @@
|
||||
#include "matrix_core_swizzle.hpp"
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
matrix_core_swizzle_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
if(t.inst.compare("32x32x8") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
else if(t.inst.compare("16x16x16") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
#include <string>
|
||||
|
||||
struct matrix_core_swizzle_traits
|
||||
{
|
||||
std::string data_type; // fp16 only
|
||||
std::string inst; // 32x32x8, 16x16x16
|
||||
std::string permute; //
|
||||
};
|
||||
|
||||
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
|
||||
|
||||
// host API
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits,
|
||||
matrix_core_swizzle_args,
|
||||
const ck_tile::stream_config&);
|
||||
@@ -0,0 +1,413 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// if set to 1, slightly more instructions generated to calculate address
|
||||
#ifndef MERGE_2D_013425
|
||||
#define MERGE_2D_013425 0
|
||||
#endif
|
||||
|
||||
enum class matrix_core_inst_enum
|
||||
{
|
||||
MFMA_32x32x8_F16 = 0,
|
||||
MFMA_16x16x16_F16 = 1,
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <matrix_core_inst_enum>
|
||||
struct to_warp_gemm;
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
|
||||
};
|
||||
} // namespace detail
|
||||
template <matrix_core_inst_enum Inst>
|
||||
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
|
||||
|
||||
// TODO: in below permute pattern, the last 3 dim is within wave
|
||||
enum class matrix_core_permute_style
|
||||
{
|
||||
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
// assume this is B matrix, originally we have batch*n*k
|
||||
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
|
||||
//
|
||||
// 4(waves) 32(mfma_m lane)
|
||||
// | |
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
|
||||
// nr kr |
|
||||
// nr 4 32 kr 2 8 2(klane)
|
||||
//
|
||||
// permute: 0,1,4,2,5,3,6
|
||||
// or
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
|
||||
// permute: 0,1,2,4,5,3,6
|
||||
//
|
||||
// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling
|
||||
// for simplicity, only consider n/k is multiple of block-size
|
||||
|
||||
// independend host arg with no template
|
||||
struct matrix_core_swizzle_host_args
|
||||
{
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
int32_t batch;
|
||||
int32_t n;
|
||||
int32_t k;
|
||||
};
|
||||
|
||||
// NOTE: this kernel could follow the style of generic permute kernel
|
||||
// but here we pass in fixed layout as template arg and generate different kernel instance
|
||||
// purposely
|
||||
template <int BLOCK_SIZE_ = 256,
|
||||
int NPerBlock_ = 256,
|
||||
int KPerBlock_ = 128,
|
||||
matrix_core_permute_style pstyle_ =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
|
||||
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
struct matrix_core_swizzle_kernel
|
||||
{
|
||||
using karg = matrix_core_swizzle_host_args;
|
||||
using harg = matrix_core_swizzle_host_args;
|
||||
|
||||
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
|
||||
static constexpr int WavesPerBlock_N = 4;
|
||||
static constexpr int WavesPerBlock_K = 1;
|
||||
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
|
||||
static constexpr int NPerBlock = NPerBlock_;
|
||||
static constexpr int KPerBlock = KPerBlock_;
|
||||
static constexpr matrix_core_permute_style pstyle = pstyle_;
|
||||
static constexpr matrix_core_inst_enum Inst = Inst_;
|
||||
|
||||
static constexpr ck_tile::index_t Alignment = 8;
|
||||
karg a;
|
||||
dim3 grids;
|
||||
|
||||
using WarpGemm = to_warp_gemm_t<Inst>;
|
||||
|
||||
__host__ matrix_core_swizzle_kernel(harg h)
|
||||
{
|
||||
a = h;
|
||||
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
|
||||
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
|
||||
grids = dim3(ks, ns, h.batch);
|
||||
}
|
||||
|
||||
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
|
||||
|
||||
__host__ void operator()(const ck_tile::stream_config& s) const
|
||||
{
|
||||
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
|
||||
}
|
||||
|
||||
struct kernel
|
||||
{
|
||||
__device__ static constexpr auto get_src_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<5, 3>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 4, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
__device__ static constexpr auto get_dst_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<3>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 2, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 3, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
|
||||
constexpr index_t Kv = Alignment;
|
||||
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
constexpr index_t Nr = NPerBlock / Nw;
|
||||
constexpr index_t Kr = KPerBlock / (Kv * Kw);
|
||||
|
||||
constexpr index_t Nr_p = WavesPerBlock_N;
|
||||
constexpr index_t Kr_p = WavesPerBlock_K;
|
||||
constexpr index_t Nr_y = Nr / Nr_p;
|
||||
constexpr index_t Kr_y = Kr / Kr_p;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
#if MERGE_2D_013425
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2
|
||||
// minor 0 1 2 0 1 2 3
|
||||
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 2>, // Y major
|
||||
sequence<0, 0, 3>>{}); // y minor
|
||||
#else
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2 3
|
||||
// minor 0 1 0 1 0 1 2
|
||||
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 3>, // Y major
|
||||
sequence<0, 0, 2>>{}); // y minor
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void operator()(karg a_)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
index_t i_k = blockIdx.x;
|
||||
index_t i_n = blockIdx.y;
|
||||
index_t i_b = blockIdx.z;
|
||||
|
||||
constexpr index_t k2 = Alignment;
|
||||
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
|
||||
const index_t k0 = a_.k / (k1 * k2);
|
||||
const index_t n0 = a_.n / (n1 * n2);
|
||||
|
||||
constexpr index_t k2_tile = Alignment;
|
||||
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
|
||||
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
|
||||
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
|
||||
|
||||
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
|
||||
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
|
||||
|
||||
const auto src_view = [&]() {
|
||||
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_src,
|
||||
make_tuple(n0, n1, n2, k0, k1, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
const auto src_window = make_tile_window(src_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
|
||||
get_src_dist());
|
||||
|
||||
auto dst_view = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, k0, n1, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, n1, k0, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
// constexpr index_t waveflatten = kw*nw*kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
|
||||
number<Alignment>{}); // control vector load
|
||||
auto tmp_1 = transform_tensor_view(
|
||||
tmp,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(nr, number<nw>{})),
|
||||
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return tmp_1;
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten = kw * nw * kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, waveflatten),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
auto dst_window = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{i_n * NPerBlock, i_k * KPerBlock},
|
||||
get_dst_dist());
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten_tile = kw * nw * kv;
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<nr_tile>{},
|
||||
number<kr_tile>{},
|
||||
number<waveflatten_tile>{}),
|
||||
{i_n * nr_tile, i_k * kr_tile, 0},
|
||||
get_dst_dist());
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
// actual load store
|
||||
auto src_tile = load_tile(src_window);
|
||||
|
||||
// now we only swap the distribution from src to dst, no extra movement occurs
|
||||
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
|
||||
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
|
||||
|
||||
// final store
|
||||
store_tile(dst_window, dst_tile);
|
||||
}
|
||||
};
|
||||
};
|
||||
411
example/ck_tile/06_permute/permute.cpp
Normal file
411
example/ck_tile/06_permute/permute.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "permute.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
#include "alternative_impl/matrix_core_swizzle.hpp"
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
template <int bytes>
|
||||
struct to_integer_type;
|
||||
|
||||
template <>
|
||||
struct to_integer_type<4>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<2>
|
||||
{
|
||||
using type = int16_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<1>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int bytes>
|
||||
using to_integer_type = typename detail::to_integer_type<bytes>::type;
|
||||
|
||||
// host API (shoule come from codegen)
|
||||
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp8") == 0)
|
||||
{
|
||||
using DataType = ck_tile::fp8_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
using DataType = ck_tile::half_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp32") == 0)
|
||||
{
|
||||
using DataType = float;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("shape", "2,3,4", "the shape of the input tensor")
|
||||
.insert("perm", "2,1,0", "permute perm")
|
||||
.insert("kname", "0", "t to 1 will print kernel name")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
// "1,2,3,4" -> vector{1,2,3,4}
|
||||
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
std::string::size_type pos = 0;
|
||||
std::vector<ck_tile::index_t> v;
|
||||
while(true)
|
||||
{
|
||||
auto found = q_val.find(',', pos);
|
||||
ck_tile::index_t n =
|
||||
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
|
||||
v.push_back(n);
|
||||
if(found == std::string::npos)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos = found + 1;
|
||||
}
|
||||
return v;
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
|
||||
auto shape = decode_vec(arg_parser.get_str("shape"));
|
||||
auto perm = decode_vec(arg_parser.get_str("perm"));
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
|
||||
assert(shape.size() == perm.size());
|
||||
ck_tile::index_t rank = perm.size();
|
||||
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
|
||||
{
|
||||
printf("rank %d permute is not support yet\n", rank);
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> x(shape);
|
||||
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
|
||||
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
// std::cout << "@@@@" << tmp << std::endl;
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
|
||||
// static_cast<int>(rank)
|
||||
// << std::endl;
|
||||
tmp[i] = shape[perm[i]];
|
||||
}
|
||||
// std::cout << "@@@" << tmp << std::endl;
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<DataType> y(y_shape);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x.data());
|
||||
|
||||
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
|
||||
<< std::flush;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
stream_warmup,
|
||||
stream_repeat};
|
||||
float ave_time = 0.f;
|
||||
auto run_permute = [&]() {
|
||||
permute_traits t;
|
||||
t.data_type = data_type;
|
||||
|
||||
permute_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.rank = rank;
|
||||
std::copy(shape.begin(), shape.end(), a.shape);
|
||||
std::copy(perm.begin(), perm.end(), a.perm);
|
||||
|
||||
return permute(t, a, stream_config);
|
||||
};
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
|
||||
{
|
||||
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
|
||||
{
|
||||
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
|
||||
auto nr = shape[1];
|
||||
auto nw = shape[2];
|
||||
auto kr = shape[3];
|
||||
auto kw = shape[4];
|
||||
auto kv = shape[5];
|
||||
a.n = nr * nw;
|
||||
a.k = kr * kw * kv;
|
||||
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
a.n = shape[1] * shape[2] * shape[3];
|
||||
a.k = shape[4] * shape[5] * shape[6];
|
||||
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
|
||||
shape[4] % 8 == 0 && shape[1] % 2 == 0)
|
||||
{
|
||||
// 32x32x8 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
|
||||
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
|
||||
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
|
||||
shape[4] % 4 == 0 && shape[1] % 4 == 0)
|
||||
{
|
||||
// 16x16x16 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,4x,4x,4,4,16,8
|
||||
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
std::cout << ", time:" << ave_time << "ms" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
reference_permute(x, y, perm);
|
||||
#if 0
|
||||
if constexpr (std::is_same_v<float, DataType>){
|
||||
// using itype = to_integer_type<sizeof(DataType)>;
|
||||
fflush(stdout);
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", x.mData[zz]);
|
||||
}
|
||||
printf("->\n");
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", y.mData[zz]);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
#endif
|
||||
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
|
||||
|
||||
y_buf.FromDevice(y_dev.data());
|
||||
|
||||
pass = std::equal(
|
||||
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
|
||||
using itype = to_integer_type<sizeof(DataType)>;
|
||||
itype i_d = ck_tile::bit_cast<itype>(d);
|
||||
itype i_h = ck_tile::bit_cast<itype>(h);
|
||||
return i_d == i_h;
|
||||
});
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp32")
|
||||
{
|
||||
return run<float>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
19
example/ck_tile/06_permute/permute.hpp
Normal file
19
example/ck_tile/06_permute/permute.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/permute.hpp"
|
||||
#include <string>
|
||||
|
||||
struct permute_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
using permute_args = ck_tile::GenericPermuteHostArgs;
|
||||
|
||||
// host API
|
||||
float permute(permute_traits, permute_args, const ck_tile::stream_config&);
|
||||
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_permute
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
if [ $# -ge 1 ] ; then
|
||||
set -x
|
||||
fi
|
||||
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
|
||||
echo "------------------------------------------------------------------"
|
||||
|
||||
for prec in "fp8" "fp16" "fp32" ; do
|
||||
|
||||
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
|
||||
echo "------------------------------------------------------------------"
|
||||
done
|
||||
@@ -8,6 +8,6 @@ add_subdirectory(03_gemm)
|
||||
add_subdirectory(04_img2col)
|
||||
add_subdirectory(05_reduce)
|
||||
add_subdirectory(06_rmsnorm2d)
|
||||
add_subdirectory(06_permute)
|
||||
add_subdirectory(07_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(09_topk_softmax)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
|
||||
replicate,
|
||||
xor_t,
|
||||
offset,
|
||||
indexing,
|
||||
};
|
||||
|
||||
template <index_t NDimLow, index_t NDimUp>
|
||||
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
struct indexing : public base_transform<1, 1>
|
||||
{
|
||||
static constexpr index_t NDimUp = 1;
|
||||
|
||||
using LowerIndex = multi_index<1>;
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths = decltype(make_tuple(UpLength{}));
|
||||
UpLengths up_lengths_;
|
||||
IndexingAdaptor iadaptor_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
|
||||
const IndexingAdaptor& iadaptor)
|
||||
: up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
|
||||
{
|
||||
return coord_transform_enum::indexing;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
iadaptor_.calculate_lower_index(idx_low, idx_up);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
|
||||
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_always_mapped_to_valid_lower_index()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
|
||||
IndexingAdaptor::is_known_at_compile_time();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("embed{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename UpLength, typename Indices>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
|
||||
const Indices& indices)
|
||||
{
|
||||
// by default we use the simplest one
|
||||
return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
|
||||
up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
|
||||
}
|
||||
|
||||
template <typename UpLength, typename IndexingAdaptor>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
|
||||
{
|
||||
return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
60
include/ck_tile/core/algorithm/indexing_adaptor.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// pre-defined indexing adaptor used for indexing(scatter/gather)
|
||||
|
||||
// this version cache the index inside thread register(which is also prefered in real senario)
|
||||
// however it's user's responsibility that each thread only provide one indexing, which means
|
||||
// move coordinate will not change on this dim
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor_onshot_cached
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
|
||||
: cached_idx_(idx)
|
||||
{
|
||||
}
|
||||
IndexingType cached_idx_;
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = cached_idx_;
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
|
||||
|
||||
57
include/ck_tile/host/reference/reference_permute.hpp
Normal file
57
include/ck_tile/host/reference/reference_permute.hpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
this will do permute + contiguous like functionality in pytorch
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST void
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
|
||||
{
|
||||
const auto x_len = x.mDesc.get_lengths();
|
||||
const auto y_len = y.mDesc.get_lengths();
|
||||
assert(x_len.size() == y_len.size());
|
||||
index_t rank = x_len.size();
|
||||
const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
|
||||
const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
|
||||
assert(x_elm == y_elm);
|
||||
(void)y_elm;
|
||||
|
||||
auto f = [&](auto i_element) {
|
||||
std::vector<size_t> y_coord = [&]() {
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
size_t r = i_element;
|
||||
for(index_t i = rank - 1; i >= 0; i--)
|
||||
{
|
||||
tmp[i] = r % y_len[i];
|
||||
r = r / y_len[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
std::vector<size_t> x_coord = [&]() {
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
for(index_t i = 0; i < rank; i++)
|
||||
{
|
||||
tmp[dims[i]] = y_coord[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
// do permute
|
||||
y(y_coord) = x(x_coord);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
8
include/ck_tile/ops/permute.hpp
Normal file
8
include/ck_tile/ops/permute.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
|
||||
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
169
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
Normal file
169
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/* independent host side argument, no template
|
||||
*/
|
||||
struct GenericPermuteHostArgs
|
||||
{
|
||||
static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
|
||||
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
index_t rank;
|
||||
index_t shape[kMaxRanks]; // input shape
|
||||
index_t perm[kMaxRanks]; // permute index
|
||||
};
|
||||
|
||||
/*
|
||||
simulate torch.permute:
|
||||
x_ = x_.view(x.shape[0],
|
||||
x.shape[1]//16, 16,
|
||||
x.shape[2]//32, 4, 8)
|
||||
x_ = x_.permute(0,1,3,4,2,5)
|
||||
x_ = x_.contiguous()
|
||||
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
|
||||
|
||||
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
|
||||
dim of permutation, with a single kernel
|
||||
|
||||
*/
|
||||
template <typename Problem_>
|
||||
struct GenericPermute
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMaxRanks = Problem::kMaxRanks;
|
||||
static constexpr bool KeepLastDim = Problem::KeepLastDim;
|
||||
|
||||
struct __attribute__((packed)) Kargs
|
||||
{
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
// index_t rank;
|
||||
index_t num_elements;
|
||||
index_t perm_length[kMaxRanks]; // tensor length after permutation
|
||||
index_t perm_stride[kMaxRanks]; // tensor stride after permutation
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h)
|
||||
{
|
||||
index_t n = 1;
|
||||
for(auto i = 0; i < h.rank; i++)
|
||||
{
|
||||
n *= h.shape[i];
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h)
|
||||
{
|
||||
Kargs a;
|
||||
a.p_src = h.p_src;
|
||||
a.p_dst = h.p_dst;
|
||||
|
||||
// assert rank <= kMaxRanks
|
||||
index_t i = 0;
|
||||
|
||||
index_t perm[kMaxRanks];
|
||||
index_t x_shape[kMaxRanks];
|
||||
index_t x_stride[kMaxRanks];
|
||||
// index_t perm_length[kMaxRanks];
|
||||
|
||||
for(; i < h.rank; i++)
|
||||
{
|
||||
x_shape[i] = h.shape[i];
|
||||
perm[i] = h.perm[i];
|
||||
}
|
||||
for(; i < kMaxRanks; i++)
|
||||
{
|
||||
x_shape[i] = 1;
|
||||
perm[i] = i; // will index to len = 1
|
||||
}
|
||||
|
||||
index_t stride = 1;
|
||||
for(index_t j = kMaxRanks - 1; j >= 0; j--)
|
||||
{
|
||||
x_stride[j] = stride;
|
||||
stride *= x_shape[j];
|
||||
}
|
||||
|
||||
for(index_t j = 0; j < kMaxRanks; j++)
|
||||
{
|
||||
a.perm_length[j] = x_shape[perm[j]];
|
||||
a.perm_stride[j] = x_stride[perm[j]];
|
||||
}
|
||||
|
||||
a.num_elements = TotalElements(h);
|
||||
return a;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h)
|
||||
{
|
||||
auto total = TotalElements(h);
|
||||
auto grids = dim3((total + BlockSize() - 1) / BlockSize());
|
||||
// printf("### total:%d, grids:%dx%dx%d\n", total, );
|
||||
return grids;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t id = blockIdx.x * BlockSize() + threadIdx.x;
|
||||
|
||||
if(id >= kargs.num_elements)
|
||||
return;
|
||||
|
||||
const auto perm_length =
|
||||
generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
|
||||
const auto perm_stride =
|
||||
generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
|
||||
|
||||
const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
|
||||
DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
|
||||
|
||||
const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_src, perm_length, perm_stride, number<1>{}, number<1>{});
|
||||
|
||||
const auto src_view = transform_tensor_view(
|
||||
src_view_0,
|
||||
make_tuple(make_merge_transform(perm_length)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst, perm_length, number<1>{});
|
||||
|
||||
auto dst_view = transform_tensor_view(
|
||||
dst_view_0,
|
||||
make_tuple(make_merge_transform(perm_length)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
// TODO: hard code to vector 1
|
||||
using vector_t = thread_buffer<DataType, 1>;
|
||||
|
||||
const auto src_coord =
|
||||
make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
|
||||
const auto dst_coord =
|
||||
make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
|
||||
|
||||
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
|
||||
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
|
||||
|
||||
const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
|
||||
dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,28 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kMaxRanks_ = 8,
|
||||
bool KeepLastDim_ = false>
|
||||
struct GenericPermuteProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMaxRanks = kMaxRanks_;
|
||||
/* KeepLastDim:
|
||||
* if last dim keep the same? this can help enable vector load
|
||||
* permute(0, 2, 4, 1, 3, 5) -> true
|
||||
* permute(0, 3, 2, 1) -> false
|
||||
*/
|
||||
static constexpr bool KeepLastDim = KeepLastDim_;
|
||||
// TODO: not used(?)
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
|
||||
add_subdirectory(smfmac_op)
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
add_subdirectory(scatter_gather)
|
||||
|
||||
2
test/scatter_gather/CMakeLists.txt
Normal file
2
test/scatter_gather/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_test_executable(test_scatter_gather scatter_gather.cpp)
|
||||
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
|
||||
276
test/scatter_gather/scatter_gather.cpp
Normal file
276
test/scatter_gather/scatter_gather.cpp
Normal file
@@ -0,0 +1,276 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#ifndef TEST_SCATTER_GATHER_VERBOSE
|
||||
#define TEST_SCATTER_GATHER_VERBOSE 1
|
||||
#endif
|
||||
|
||||
#define HIP_CALL(call) \
|
||||
do \
|
||||
{ \
|
||||
hipError_t err = call; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
|
||||
exit(0); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
/*
|
||||
TODO:
|
||||
This is a simple design of scatter/gather through indexing transform, with limitations
|
||||
We may design a scatter/gather adaptor layer directly inside tile window
|
||||
*/
|
||||
template <ck_tile::index_t ROW_TILE_SIZE = 8,
|
||||
ck_tile::index_t COL_TILE_SIZE = 32 * 8,
|
||||
ck_tile::index_t BLOCK_SIZE = 256,
|
||||
ck_tile::index_t ALIGNMENT = 8,
|
||||
typename INDEX_BUF_TYPE = ck_tile::index_t,
|
||||
typename DATA_TYPE = ck_tile::fp16_t>
|
||||
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
|
||||
const INDEX_BUF_TYPE* dst_row_idx_ptr,
|
||||
const DATA_TYPE* src_ptr,
|
||||
DATA_TYPE* dst_ptr,
|
||||
ck_tile::index_t n_row_total,
|
||||
ck_tile::index_t /*n_row_select*/,
|
||||
ck_tile::index_t n_cols)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// some constexpr vars
|
||||
constexpr index_t vec = ALIGNMENT;
|
||||
static_assert(COL_TILE_SIZE % vec == 0);
|
||||
constexpr index_t col_lanes = COL_TILE_SIZE / vec;
|
||||
constexpr index_t warp_size = ck_tile::get_warp_size();
|
||||
static_assert(warp_size % col_lanes == 0);
|
||||
constexpr index_t row_lanes = warp_size / col_lanes;
|
||||
constexpr index_t num_warps = BLOCK_SIZE / warp_size;
|
||||
static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
|
||||
constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
|
||||
static_assert(
|
||||
row_repeat == 1,
|
||||
"currently indexing not support(and would be not performant) if row_repeat has more");
|
||||
|
||||
// tile partitioner
|
||||
index_t tile_col_idx = 0;
|
||||
index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;
|
||||
|
||||
// create our tild distribution, which tell us the location of different threads
|
||||
constexpr auto src_dist = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
const auto coord = src_dist.calculate_index();
|
||||
const auto row_coord = coord[number<0>{}] + tile_row_idx;
|
||||
|
||||
// load the current row index from the indexing buffer. we do not use ck_tile utility here
|
||||
INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
|
||||
INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
|
||||
|
||||
// printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
|
||||
// static_cast<int>(src_row_id), static_cast<int>(dst_row_id));
|
||||
|
||||
const auto src_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(src_ptr,
|
||||
make_tuple(n_row_total, n_cols),
|
||||
make_tuple(n_cols, 1),
|
||||
number<vec>{}, // alignement
|
||||
number<1>{});
|
||||
|
||||
const auto src_gather_view = transform_tensor_view(
|
||||
src_view,
|
||||
make_tuple(make_indexing_transform(
|
||||
n_row_total,
|
||||
src_row_id), // here we replace row_idx which is loaded from another buffer
|
||||
make_pass_through_transform(n_cols)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto src_tile = make_tile_window(src_gather_view,
|
||||
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
|
||||
{tile_row_idx, tile_col_idx},
|
||||
src_dist);
|
||||
|
||||
const auto dst_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(dst_ptr,
|
||||
make_tuple(n_row_total, n_cols),
|
||||
make_tuple(n_cols, 1),
|
||||
number<vec>{},
|
||||
number<1>{});
|
||||
|
||||
const auto dst_scatter_view = transform_tensor_view(
|
||||
dst_view,
|
||||
make_tuple(make_indexing_transform(
|
||||
n_row_total,
|
||||
dst_row_id), // here we replace row_idx which is loaded from another buffer
|
||||
make_pass_through_transform(n_cols)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
auto dst_tile = make_tile_window(dst_scatter_view,
|
||||
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
|
||||
{tile_row_idx, tile_col_idx},
|
||||
src_dist /*reuse distribution*/);
|
||||
|
||||
// we finished descriptor construction and index calculation, now start load/store
|
||||
for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
|
||||
{
|
||||
// note that scatter/gather are just the same API when doing load store as normal memory
|
||||
// operation
|
||||
auto data = load_tile(src_tile);
|
||||
store_tile(dst_tile, data);
|
||||
|
||||
move_tile_window(src_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
|
||||
move_tile_window(dst_tile, {number<0>{}, number<COL_TILE_SIZE>{}});
|
||||
}
|
||||
}
|
||||
|
||||
union pixel
|
||||
{
|
||||
struct __attribute__((packed))
|
||||
{
|
||||
unsigned int r : 6;
|
||||
unsigned int c : 10;
|
||||
};
|
||||
ushort data;
|
||||
};
|
||||
|
||||
struct unique_linear_rand
|
||||
{
|
||||
unique_linear_rand(int capacity_) : capacity(capacity_) {}
|
||||
std::unordered_set<int> set;
|
||||
int gen()
|
||||
{
|
||||
if(static_cast<int>(set.size()) >= capacity)
|
||||
{
|
||||
printf("overflow, but will give you an number as well\n");
|
||||
return std::rand() % capacity;
|
||||
}
|
||||
while(1)
|
||||
{
|
||||
int r = std::rand() % capacity;
|
||||
if(set.count(r) == 1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
set.insert(r);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
int capacity;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
int row_total = 64;
|
||||
int row_select = 8 * 2;
|
||||
int col = 256 * 2;
|
||||
using fp16_t = ck_tile::fp16_t;
|
||||
|
||||
constexpr int row_tile = 8;
|
||||
constexpr int col_tile = 256;
|
||||
|
||||
fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
|
||||
for(int i_r = 0; i_r < row_total; i_r++)
|
||||
{
|
||||
for(int i_c = 0; i_c < col; i_c++)
|
||||
{
|
||||
int i = i_r * col + i_c;
|
||||
pixel p;
|
||||
p.r = i_r;
|
||||
p.c = i_c;
|
||||
ushort d = p.data;
|
||||
src[i] = ck_tile::bit_cast<fp16_t>(d); // for simplicity, just cast
|
||||
}
|
||||
}
|
||||
|
||||
fp16_t* dst = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
|
||||
int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
|
||||
int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
|
||||
// std::srand(std::time(std::nullptr));
|
||||
// std::srand(11935);
|
||||
std::srand(std::time(nullptr));
|
||||
auto src_gen = unique_linear_rand(row_total);
|
||||
auto dst_gen = unique_linear_rand(row_total); // dst index must be unique. src is fine
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
src_idx[i_r] = src_gen.gen();
|
||||
dst_idx[i_r] = dst_gen.gen();
|
||||
}
|
||||
|
||||
void* dev_src;
|
||||
void* dev_dst;
|
||||
void* dev_src_idx;
|
||||
void* dev_dst_idx;
|
||||
HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t)));
|
||||
HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t)));
|
||||
HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int)));
|
||||
HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int)));
|
||||
|
||||
HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice));
|
||||
HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
|
||||
HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
|
||||
|
||||
constexpr int bdim = 256;
|
||||
int gdim = (row_select + row_tile - 1) / row_tile;
|
||||
row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx),
|
||||
reinterpret_cast<int*>(dev_dst_idx),
|
||||
reinterpret_cast<fp16_t*>(dev_src),
|
||||
reinterpret_cast<fp16_t*>(dev_dst),
|
||||
row_total,
|
||||
row_select,
|
||||
col);
|
||||
|
||||
HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost));
|
||||
|
||||
#if TEST_SCATTER_GATHER_VERBOSE
|
||||
printf("select row:");
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]);
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
|
||||
int err_cnt = 0;
|
||||
for(int i_r = 0; i_r < row_select; i_r++)
|
||||
{
|
||||
for(int i_c = 0; i_c < col; i_c++)
|
||||
{
|
||||
int i = dst_idx[i_r] * col + i_c;
|
||||
pixel p = ck_tile::bit_cast<pixel>(dst[i]);
|
||||
bool is_ok = p.r == src_idx[i_r] && p.c == i_c;
|
||||
if(!is_ok)
|
||||
{
|
||||
if(i_c == 0)
|
||||
printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]);
|
||||
err_cnt++;
|
||||
}
|
||||
}
|
||||
}
|
||||
#if TEST_SCATTER_GATHER_VERBOSE
|
||||
printf("err:%d\n", err_cnt);
|
||||
#endif
|
||||
|
||||
free(src);
|
||||
free(dst);
|
||||
free(src_idx);
|
||||
free(dst_idx);
|
||||
return err_cnt == 0 ? 0 : -1;
|
||||
}
|
||||
Reference in New Issue
Block a user