mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] add generic_permute (#1607)
[ROCm/composable_kernel commit: 9fbd72e97e]
This commit is contained in:
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
|
||||
|
||||
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
|
||||
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
|
||||
endif()
|
||||
if(PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
|
||||
endif()
|
||||
# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)
|
||||
46
example/ck_tile/06_permute/README.md
Normal file
46
example/ck_tile/06_permute/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# permute
|
||||
|
||||
This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example.
|
||||
|
||||
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-prec data type. fp16/bf16/fp32 (default:fp16)
|
||||
-shape the shape of the input tensor (default:2,3,4)
|
||||
-perm permute perm (default:2,1,0)
|
||||
```
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_permute -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_permute`
|
||||
|
||||
|
||||
## some examples
|
||||
```
|
||||
# torch
|
||||
x=torch.randn(2,3,4,6)
|
||||
y=x.permute(0,3,2,1).contiguous()
|
||||
|
||||
# ck_tile
|
||||
./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1
|
||||
```
|
||||
|
||||
or you can try the smoke_test
|
||||
```
|
||||
# in the root of ck_tile, after you build this example
|
||||
sh example/ck_tile/06_permute/script/smoke_test.sh
|
||||
```
|
||||
|
||||
### alternative implementation
|
||||
we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail.
|
||||
```
|
||||
# example
|
||||
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
|
||||
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
|
||||
```
|
||||
@@ -0,0 +1,98 @@
|
||||
#include "matrix_core_swizzle.hpp"
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
matrix_core_swizzle_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
if(t.inst.compare("32x32x8") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
else if(t.inst.compare("16x16x16") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
#include <string>
|
||||
|
||||
struct matrix_core_swizzle_traits
|
||||
{
|
||||
std::string data_type; // fp16 only
|
||||
std::string inst; // 32x32x8, 16x16x16
|
||||
std::string permute; //
|
||||
};
|
||||
|
||||
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
|
||||
|
||||
// host API
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits,
|
||||
matrix_core_swizzle_args,
|
||||
const ck_tile::stream_config&);
|
||||
@@ -0,0 +1,413 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// if set to 1, slightly more instructions generated to calculate address
|
||||
#ifndef MERGE_2D_013425
|
||||
#define MERGE_2D_013425 0
|
||||
#endif
|
||||
|
||||
enum class matrix_core_inst_enum
|
||||
{
|
||||
MFMA_32x32x8_F16 = 0,
|
||||
MFMA_16x16x16_F16 = 1,
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <matrix_core_inst_enum>
|
||||
struct to_warp_gemm;
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
|
||||
};
|
||||
} // namespace detail
|
||||
template <matrix_core_inst_enum Inst>
|
||||
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
|
||||
|
||||
// TODO: in below permute pattern, the last 3 dim is within wave
|
||||
enum class matrix_core_permute_style
|
||||
{
|
||||
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
// assume this is B matrix, originally we have batch*n*k
|
||||
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
|
||||
//
|
||||
// 4(waves) 32(mfma_m lane)
|
||||
// | |
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
|
||||
// nr kr |
|
||||
// nr 4 32 kr 2 8 2(klane)
|
||||
//
|
||||
// permute: 0,1,4,2,5,3,6
|
||||
// or
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
|
||||
// permute: 0,1,2,4,5,3,6
|
||||
//
|
||||
// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling
|
||||
// for simplicity, only consider n/k is multiple of block-size
|
||||
|
||||
// independend host arg with no template
|
||||
struct matrix_core_swizzle_host_args
|
||||
{
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
int32_t batch;
|
||||
int32_t n;
|
||||
int32_t k;
|
||||
};
|
||||
|
||||
// NOTE: this kernel could follow the style of generic permute kernel
|
||||
// but here we pass in fixed layout as template arg and generate different kernel instance
|
||||
// purposely
|
||||
template <int BLOCK_SIZE_ = 256,
|
||||
int NPerBlock_ = 256,
|
||||
int KPerBlock_ = 128,
|
||||
matrix_core_permute_style pstyle_ =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
|
||||
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
struct matrix_core_swizzle_kernel
|
||||
{
|
||||
using karg = matrix_core_swizzle_host_args;
|
||||
using harg = matrix_core_swizzle_host_args;
|
||||
|
||||
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
|
||||
static constexpr int WavesPerBlock_N = 4;
|
||||
static constexpr int WavesPerBlock_K = 1;
|
||||
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
|
||||
static constexpr int NPerBlock = NPerBlock_;
|
||||
static constexpr int KPerBlock = KPerBlock_;
|
||||
static constexpr matrix_core_permute_style pstyle = pstyle_;
|
||||
static constexpr matrix_core_inst_enum Inst = Inst_;
|
||||
|
||||
static constexpr ck_tile::index_t Alignment = 8;
|
||||
karg a;
|
||||
dim3 grids;
|
||||
|
||||
using WarpGemm = to_warp_gemm_t<Inst>;
|
||||
|
||||
__host__ matrix_core_swizzle_kernel(harg h)
|
||||
{
|
||||
a = h;
|
||||
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
|
||||
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
|
||||
grids = dim3(ks, ns, h.batch);
|
||||
}
|
||||
|
||||
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
|
||||
|
||||
__host__ void operator()(const ck_tile::stream_config& s) const
|
||||
{
|
||||
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
|
||||
}
|
||||
|
||||
struct kernel
|
||||
{
|
||||
__device__ static constexpr auto get_src_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<5, 3>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 4, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
__device__ static constexpr auto get_dst_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<3>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 2, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 3, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
|
||||
constexpr index_t Kv = Alignment;
|
||||
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
constexpr index_t Nr = NPerBlock / Nw;
|
||||
constexpr index_t Kr = KPerBlock / (Kv * Kw);
|
||||
|
||||
constexpr index_t Nr_p = WavesPerBlock_N;
|
||||
constexpr index_t Kr_p = WavesPerBlock_K;
|
||||
constexpr index_t Nr_y = Nr / Nr_p;
|
||||
constexpr index_t Kr_y = Kr / Kr_p;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
#if MERGE_2D_013425
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2
|
||||
// minor 0 1 2 0 1 2 3
|
||||
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 2>, // Y major
|
||||
sequence<0, 0, 3>>{}); // y minor
|
||||
#else
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2 3
|
||||
// minor 0 1 0 1 0 1 2
|
||||
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 3>, // Y major
|
||||
sequence<0, 0, 2>>{}); // y minor
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void operator()(karg a_)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
index_t i_k = blockIdx.x;
|
||||
index_t i_n = blockIdx.y;
|
||||
index_t i_b = blockIdx.z;
|
||||
|
||||
constexpr index_t k2 = Alignment;
|
||||
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
|
||||
const index_t k0 = a_.k / (k1 * k2);
|
||||
const index_t n0 = a_.n / (n1 * n2);
|
||||
|
||||
constexpr index_t k2_tile = Alignment;
|
||||
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
|
||||
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
|
||||
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
|
||||
|
||||
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
|
||||
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
|
||||
|
||||
const auto src_view = [&]() {
|
||||
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_src,
|
||||
make_tuple(n0, n1, n2, k0, k1, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
const auto src_window = make_tile_window(src_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
|
||||
get_src_dist());
|
||||
|
||||
auto dst_view = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, k0, n1, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, n1, k0, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
// constexpr index_t waveflatten = kw*nw*kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
|
||||
number<Alignment>{}); // control vector load
|
||||
auto tmp_1 = transform_tensor_view(
|
||||
tmp,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(nr, number<nw>{})),
|
||||
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return tmp_1;
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten = kw * nw * kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, waveflatten),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
auto dst_window = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{i_n * NPerBlock, i_k * KPerBlock},
|
||||
get_dst_dist());
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten_tile = kw * nw * kv;
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<nr_tile>{},
|
||||
number<kr_tile>{},
|
||||
number<waveflatten_tile>{}),
|
||||
{i_n * nr_tile, i_k * kr_tile, 0},
|
||||
get_dst_dist());
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
// actual load store
|
||||
auto src_tile = load_tile(src_window);
|
||||
|
||||
// now we only swap the distribution from src to dst, no extra movement occurs
|
||||
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
|
||||
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
|
||||
|
||||
// final store
|
||||
store_tile(dst_window, dst_tile);
|
||||
}
|
||||
};
|
||||
};
|
||||
411
example/ck_tile/06_permute/permute.cpp
Normal file
411
example/ck_tile/06_permute/permute.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "permute.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
#include "alternative_impl/matrix_core_swizzle.hpp"
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
template <int bytes>
|
||||
struct to_integer_type;
|
||||
|
||||
template <>
|
||||
struct to_integer_type<4>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<2>
|
||||
{
|
||||
using type = int16_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<1>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int bytes>
|
||||
using to_integer_type = typename detail::to_integer_type<bytes>::type;
|
||||
|
||||
// host API (shoule come from codegen)
|
||||
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp8") == 0)
|
||||
{
|
||||
using DataType = ck_tile::fp8_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
using DataType = ck_tile::half_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp32") == 0)
|
||||
{
|
||||
using DataType = float;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("shape", "2,3,4", "the shape of the input tensor")
|
||||
.insert("perm", "2,1,0", "permute perm")
|
||||
.insert("kname", "0", "t to 1 will print kernel name")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
// "1,2,3,4" -> vector{1,2,3,4}
|
||||
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
std::string::size_type pos = 0;
|
||||
std::vector<ck_tile::index_t> v;
|
||||
while(true)
|
||||
{
|
||||
auto found = q_val.find(',', pos);
|
||||
ck_tile::index_t n =
|
||||
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
|
||||
v.push_back(n);
|
||||
if(found == std::string::npos)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos = found + 1;
|
||||
}
|
||||
return v;
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
|
||||
auto shape = decode_vec(arg_parser.get_str("shape"));
|
||||
auto perm = decode_vec(arg_parser.get_str("perm"));
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
|
||||
assert(shape.size() == perm.size());
|
||||
ck_tile::index_t rank = perm.size();
|
||||
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
|
||||
{
|
||||
printf("rank %d permute is not support yet\n", rank);
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> x(shape);
|
||||
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
|
||||
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
// std::cout << "@@@@" << tmp << std::endl;
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
|
||||
// static_cast<int>(rank)
|
||||
// << std::endl;
|
||||
tmp[i] = shape[perm[i]];
|
||||
}
|
||||
// std::cout << "@@@" << tmp << std::endl;
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<DataType> y(y_shape);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x.data());
|
||||
|
||||
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
|
||||
<< std::flush;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
stream_warmup,
|
||||
stream_repeat};
|
||||
float ave_time = 0.f;
|
||||
auto run_permute = [&]() {
|
||||
permute_traits t;
|
||||
t.data_type = data_type;
|
||||
|
||||
permute_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.rank = rank;
|
||||
std::copy(shape.begin(), shape.end(), a.shape);
|
||||
std::copy(perm.begin(), perm.end(), a.perm);
|
||||
|
||||
return permute(t, a, stream_config);
|
||||
};
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
|
||||
{
|
||||
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
|
||||
{
|
||||
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
|
||||
auto nr = shape[1];
|
||||
auto nw = shape[2];
|
||||
auto kr = shape[3];
|
||||
auto kw = shape[4];
|
||||
auto kv = shape[5];
|
||||
a.n = nr * nw;
|
||||
a.k = kr * kw * kv;
|
||||
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
a.n = shape[1] * shape[2] * shape[3];
|
||||
a.k = shape[4] * shape[5] * shape[6];
|
||||
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
|
||||
shape[4] % 8 == 0 && shape[1] % 2 == 0)
|
||||
{
|
||||
// 32x32x8 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
|
||||
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
|
||||
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
|
||||
shape[4] % 4 == 0 && shape[1] % 4 == 0)
|
||||
{
|
||||
// 16x16x16 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,4x,4x,4,4,16,8
|
||||
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
std::cout << ", time:" << ave_time << "ms" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
reference_permute(x, y, perm);
|
||||
#if 0
|
||||
if constexpr (std::is_same_v<float, DataType>){
|
||||
// using itype = to_integer_type<sizeof(DataType)>;
|
||||
fflush(stdout);
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", x.mData[zz]);
|
||||
}
|
||||
printf("->\n");
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", y.mData[zz]);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
#endif
|
||||
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
|
||||
|
||||
y_buf.FromDevice(y_dev.data());
|
||||
|
||||
pass = std::equal(
|
||||
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
|
||||
using itype = to_integer_type<sizeof(DataType)>;
|
||||
itype i_d = ck_tile::bit_cast<itype>(d);
|
||||
itype i_h = ck_tile::bit_cast<itype>(h);
|
||||
return i_d == i_h;
|
||||
});
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp32")
|
||||
{
|
||||
return run<float>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
19
example/ck_tile/06_permute/permute.hpp
Normal file
19
example/ck_tile/06_permute/permute.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/permute.hpp"
|
||||
#include <string>
|
||||
|
||||
struct permute_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
using permute_args = ck_tile::GenericPermuteHostArgs;
|
||||
|
||||
// host API
|
||||
float permute(permute_traits, permute_args, const ck_tile::stream_config&);
|
||||
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_permute
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
if [ $# -ge 1 ] ; then
|
||||
set -x
|
||||
fi
|
||||
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
|
||||
echo "------------------------------------------------------------------"
|
||||
|
||||
for prec in "fp8" "fp16" "fp32" ; do
|
||||
|
||||
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
|
||||
echo "------------------------------------------------------------------"
|
||||
done
|
||||
@@ -7,5 +7,6 @@ add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
add_subdirectory(04_img2col)
|
||||
add_subdirectory(05_reduce)
|
||||
add_subdirectory(06_permute)
|
||||
add_subdirectory(09_topk_softmax)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
|
||||
57
include/ck_tile/host/reference/reference_permute.hpp
Normal file
57
include/ck_tile/host/reference/reference_permute.hpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
this will do permute + contiguous like functionality in pytorch
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST void
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
|
||||
{
|
||||
const auto x_len = x.mDesc.get_lengths();
|
||||
const auto y_len = y.mDesc.get_lengths();
|
||||
assert(x_len.size() == y_len.size());
|
||||
index_t rank = x_len.size();
|
||||
const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
|
||||
const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
|
||||
assert(x_elm == y_elm);
|
||||
(void)y_elm;
|
||||
|
||||
auto f = [&](auto i_element) {
|
||||
std::vector<size_t> y_coord = [&]() {
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
size_t r = i_element;
|
||||
for(index_t i = rank - 1; i >= 0; i--)
|
||||
{
|
||||
tmp[i] = r % y_len[i];
|
||||
r = r / y_len[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
std::vector<size_t> x_coord = [&]() {
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
for(index_t i = 0; i < rank; i++)
|
||||
{
|
||||
tmp[dims[i]] = y_coord[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
// do permute
|
||||
y(y_coord) = x(x_coord);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
8
include/ck_tile/ops/permute.hpp
Normal file
8
include/ck_tile/ops/permute.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
|
||||
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
169
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
Normal file
169
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/* independent host side argument, no template
|
||||
*/
|
||||
struct GenericPermuteHostArgs
|
||||
{
|
||||
static constexpr index_t kMaxRanks = 8; // TODO: hardcoded
|
||||
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
index_t rank;
|
||||
index_t shape[kMaxRanks]; // input shape
|
||||
index_t perm[kMaxRanks]; // permute index
|
||||
};
|
||||
|
||||
/*
|
||||
simulate torch.permute:
|
||||
x_ = x_.view(x.shape[0],
|
||||
x.shape[1]//16, 16,
|
||||
x.shape[2]//32, 4, 8)
|
||||
x_ = x_.permute(0,1,3,4,2,5)
|
||||
x_ = x_.contiguous()
|
||||
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
|
||||
|
||||
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
|
||||
dim of permutation, with a single kernel
|
||||
|
||||
*/
|
||||
template <typename Problem_>
|
||||
struct GenericPermute
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
|
||||
using DataType = remove_cvref_t<typename Problem::DataType>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMaxRanks = Problem::kMaxRanks;
|
||||
static constexpr bool KeepLastDim = Problem::KeepLastDim;
|
||||
|
||||
struct __attribute__((packed)) Kargs
|
||||
{
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
// index_t rank;
|
||||
index_t num_elements;
|
||||
index_t perm_length[kMaxRanks]; // tensor length after permutation
|
||||
index_t perm_stride[kMaxRanks]; // tensor stride after permutation
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h)
|
||||
{
|
||||
index_t n = 1;
|
||||
for(auto i = 0; i < h.rank; i++)
|
||||
{
|
||||
n *= h.shape[i];
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h)
|
||||
{
|
||||
Kargs a;
|
||||
a.p_src = h.p_src;
|
||||
a.p_dst = h.p_dst;
|
||||
|
||||
// assert rank <= kMaxRanks
|
||||
index_t i = 0;
|
||||
|
||||
index_t perm[kMaxRanks];
|
||||
index_t x_shape[kMaxRanks];
|
||||
index_t x_stride[kMaxRanks];
|
||||
// index_t perm_length[kMaxRanks];
|
||||
|
||||
for(; i < h.rank; i++)
|
||||
{
|
||||
x_shape[i] = h.shape[i];
|
||||
perm[i] = h.perm[i];
|
||||
}
|
||||
for(; i < kMaxRanks; i++)
|
||||
{
|
||||
x_shape[i] = 1;
|
||||
perm[i] = i; // will index to len = 1
|
||||
}
|
||||
|
||||
index_t stride = 1;
|
||||
for(index_t j = kMaxRanks - 1; j >= 0; j--)
|
||||
{
|
||||
x_stride[j] = stride;
|
||||
stride *= x_shape[j];
|
||||
}
|
||||
|
||||
for(index_t j = 0; j < kMaxRanks; j++)
|
||||
{
|
||||
a.perm_length[j] = x_shape[perm[j]];
|
||||
a.perm_stride[j] = x_stride[perm[j]];
|
||||
}
|
||||
|
||||
a.num_elements = TotalElements(h);
|
||||
return a;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h)
|
||||
{
|
||||
auto total = TotalElements(h);
|
||||
auto grids = dim3((total + BlockSize() - 1) / BlockSize());
|
||||
// printf("### total:%d, grids:%dx%dx%d\n", total, );
|
||||
return grids;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t id = blockIdx.x * BlockSize() + threadIdx.x;
|
||||
|
||||
if(id >= kargs.num_elements)
|
||||
return;
|
||||
|
||||
const auto perm_length =
|
||||
generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number<kMaxRanks>{});
|
||||
const auto perm_stride =
|
||||
generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number<kMaxRanks>{});
|
||||
|
||||
const DataType* p_src = reinterpret_cast<const DataType*>(kargs.p_src);
|
||||
DataType* p_dst = reinterpret_cast<DataType*>(kargs.p_dst);
|
||||
|
||||
const auto src_view_0 = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_src, perm_length, perm_stride, number<1>{}, number<1>{});
|
||||
|
||||
const auto src_view = transform_tensor_view(
|
||||
src_view_0,
|
||||
make_tuple(make_merge_transform(perm_length)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
auto dst_view_0 = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst, perm_length, number<1>{});
|
||||
|
||||
auto dst_view = transform_tensor_view(
|
||||
dst_view_0,
|
||||
make_tuple(make_merge_transform(perm_length)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
|
||||
// TODO: hard code to vector 1
|
||||
using vector_t = thread_buffer<DataType, 1>;
|
||||
|
||||
const auto src_coord =
|
||||
make_tensor_coordinate(src_view.get_tensor_descriptor(), array<index_t, 1>{id});
|
||||
const auto dst_coord =
|
||||
make_tensor_coordinate(dst_view.get_tensor_descriptor(), array<index_t, 1>{id});
|
||||
|
||||
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
|
||||
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
|
||||
|
||||
const vector_t x = src_view.template get_vectorized_elements<vector_t>(src_coord, 0);
|
||||
dst_view.template set_vectorized_elements<vector_t>(dst_coord, 0, x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,28 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kMaxRanks_ = 8,
|
||||
bool KeepLastDim_ = false>
|
||||
struct GenericPermuteProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMaxRanks = kMaxRanks_;
|
||||
/* KeepLastDim:
|
||||
* if last dim keep the same? this can help enable vector load
|
||||
* permute(0, 2, 4, 1, 3, 5) -> true
|
||||
* permute(0, 3, 2, 1) -> false
|
||||
*/
|
||||
static constexpr bool KeepLastDim = KeepLastDim_;
|
||||
// TODO: not used(?)
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user