mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] add generic_permute (#1607)
This commit is contained in:
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
13
example/ck_tile/06_permute/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
|
||||
|
||||
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
|
||||
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
|
||||
endif()
|
||||
if(PERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
|
||||
target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
|
||||
endif()
|
||||
# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)
|
||||
46
example/ck_tile/06_permute/README.md
Normal file
46
example/ck_tile/06_permute/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# permute
|
||||
|
||||
This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example.
|
||||
|
||||
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-prec data type. fp16/bf16/fp32 (default:fp16)
|
||||
-shape the shape of the input tensor (default:2,3,4)
|
||||
-perm permute perm (default:2,1,0)
|
||||
```
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_permute -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_permute`
|
||||
|
||||
|
||||
## some examples
|
||||
```
|
||||
# torch
|
||||
x=torch.randn(2,3,4,6)
|
||||
y=x.permute(0,3,2,1).contiguous()
|
||||
|
||||
# ck_tile
|
||||
./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1
|
||||
```
|
||||
|
||||
or you can try the smoke_test
|
||||
```
|
||||
# in the root of ck_tile, after you build this example
|
||||
sh example/ck_tile/06_permute/script/smoke_test.sh
|
||||
```
|
||||
|
||||
### alternative implementation
|
||||
we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail.
|
||||
```
|
||||
# example
|
||||
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
|
||||
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
|
||||
```
|
||||
@@ -0,0 +1,98 @@
|
||||
#include "matrix_core_swizzle.hpp"
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
matrix_core_swizzle_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
if(t.inst.compare("32x32x8") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
else if(t.inst.compare("16x16x16") == 0)
|
||||
{
|
||||
constexpr int BLOCK_SIZE = 256;
|
||||
constexpr int NPerBlock = 256;
|
||||
constexpr int KPerBlock = 128;
|
||||
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
|
||||
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
auto k = Kernel(a);
|
||||
float ave_time = ck_tile::launch_kernel(s, k);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "matrix_core_swizzle_kernel.hpp"
|
||||
#include <string>
|
||||
|
||||
struct matrix_core_swizzle_traits
|
||||
{
|
||||
std::string data_type; // fp16 only
|
||||
std::string inst; // 32x32x8, 16x16x16
|
||||
std::string permute; //
|
||||
};
|
||||
|
||||
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
|
||||
|
||||
// host API
|
||||
float matrix_core_swizzle(matrix_core_swizzle_traits,
|
||||
matrix_core_swizzle_args,
|
||||
const ck_tile::stream_config&);
|
||||
@@ -0,0 +1,413 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// if set to 1, slightly more instructions generated to calculate address
|
||||
#ifndef MERGE_2D_013425
|
||||
#define MERGE_2D_013425 0
|
||||
#endif
|
||||
|
||||
enum class matrix_core_inst_enum
|
||||
{
|
||||
MFMA_32x32x8_F16 = 0,
|
||||
MFMA_16x16x16_F16 = 1,
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <matrix_core_inst_enum>
|
||||
struct to_warp_gemm;
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
|
||||
{
|
||||
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
|
||||
};
|
||||
} // namespace detail
|
||||
template <matrix_core_inst_enum Inst>
|
||||
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
|
||||
|
||||
// TODO: in below permute pattern, the last 3 dim is within wave
|
||||
enum class matrix_core_permute_style
|
||||
{
|
||||
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
// assume this is B matrix, originally we have batch*n*k
|
||||
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
|
||||
//
|
||||
// 4(waves) 32(mfma_m lane)
|
||||
// | |
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
|
||||
// nr kr |
|
||||
// nr 4 32 kr 2 8 2(klane)
|
||||
//
|
||||
// permute: 0,1,4,2,5,3,6
|
||||
// or
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
|
||||
// permute: 0,1,2,4,5,3,6
|
||||
//
|
||||
// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling
|
||||
// for simplicity, only consider n/k is multiple of block-size
|
||||
|
||||
// independend host arg with no template
|
||||
struct matrix_core_swizzle_host_args
|
||||
{
|
||||
const void* p_src;
|
||||
void* p_dst;
|
||||
int32_t batch;
|
||||
int32_t n;
|
||||
int32_t k;
|
||||
};
|
||||
|
||||
// NOTE: this kernel could follow the style of generic permute kernel
|
||||
// but here we pass in fixed layout as template arg and generate different kernel instance
|
||||
// purposely
|
||||
template <int BLOCK_SIZE_ = 256,
|
||||
int NPerBlock_ = 256,
|
||||
int KPerBlock_ = 128,
|
||||
matrix_core_permute_style pstyle_ =
|
||||
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
|
||||
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
|
||||
struct matrix_core_swizzle_kernel
|
||||
{
|
||||
using karg = matrix_core_swizzle_host_args;
|
||||
using harg = matrix_core_swizzle_host_args;
|
||||
|
||||
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
|
||||
static constexpr int WavesPerBlock_N = 4;
|
||||
static constexpr int WavesPerBlock_K = 1;
|
||||
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
|
||||
static constexpr int NPerBlock = NPerBlock_;
|
||||
static constexpr int KPerBlock = KPerBlock_;
|
||||
static constexpr matrix_core_permute_style pstyle = pstyle_;
|
||||
static constexpr matrix_core_inst_enum Inst = Inst_;
|
||||
|
||||
static constexpr ck_tile::index_t Alignment = 8;
|
||||
karg a;
|
||||
dim3 grids;
|
||||
|
||||
using WarpGemm = to_warp_gemm_t<Inst>;
|
||||
|
||||
__host__ matrix_core_swizzle_kernel(harg h)
|
||||
{
|
||||
a = h;
|
||||
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
|
||||
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
|
||||
grids = dim3(ks, ns, h.batch);
|
||||
}
|
||||
|
||||
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
|
||||
|
||||
__host__ void operator()(const ck_tile::stream_config& s) const
|
||||
{
|
||||
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
|
||||
}
|
||||
|
||||
struct kernel
|
||||
{
|
||||
__device__ static constexpr auto get_src_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<5, 3>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 4, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
__device__ static constexpr auto get_dst_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
constexpr index_t K2 = Alignment;
|
||||
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
|
||||
|
||||
static_assert(NPerBlock % (N1 * N2) == 0);
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
constexpr index_t N0 = NPerBlock / (N1 * N2);
|
||||
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<3>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 2, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
// clang-format off
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0
|
||||
// 1 2 3 4 5 6
|
||||
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
|
||||
|
||||
// N1 K1 N2
|
||||
tuple<sequence<2>, sequence<4, 5>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
|
||||
// N0 K0 K2
|
||||
sequence<1, 3, 6>,
|
||||
sequence<0, 0, 0>>{});
|
||||
// clang-format on
|
||||
}
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
|
||||
constexpr index_t Kv = Alignment;
|
||||
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
|
||||
static_assert(KPerBlock % (K1 * K2) == 0);
|
||||
constexpr index_t Nr = NPerBlock / Nw;
|
||||
constexpr index_t Kr = KPerBlock / (Kv * Kw);
|
||||
|
||||
constexpr index_t Nr_p = WavesPerBlock_N;
|
||||
constexpr index_t Kr_p = WavesPerBlock_K;
|
||||
constexpr index_t Nr_y = Nr / Nr_p;
|
||||
constexpr index_t Kr_y = Kr / Kr_p;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
#if MERGE_2D_013425
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2
|
||||
// minor 0 1 2 0 1 2 3
|
||||
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 2>, // Y major
|
||||
sequence<0, 0, 3>>{}); // y minor
|
||||
#else
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,// 0 R
|
||||
// major 1 2 3
|
||||
// minor 0 1 0 1 0 1 2
|
||||
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
|
||||
|
||||
// Nr_p, Kr_p Kw Nw
|
||||
tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
|
||||
tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
|
||||
|
||||
// Nr_y Kr_y Kv
|
||||
sequence<1, 2, 3>, // Y major
|
||||
sequence<0, 0, 2>>{}); // y minor
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void operator()(karg a_)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
index_t i_k = blockIdx.x;
|
||||
index_t i_n = blockIdx.y;
|
||||
index_t i_b = blockIdx.z;
|
||||
|
||||
constexpr index_t k2 = Alignment;
|
||||
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
|
||||
const index_t k0 = a_.k / (k1 * k2);
|
||||
const index_t n0 = a_.n / (n1 * n2);
|
||||
|
||||
constexpr index_t k2_tile = Alignment;
|
||||
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
|
||||
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
|
||||
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
|
||||
|
||||
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
|
||||
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
|
||||
|
||||
const auto src_view = [&]() {
|
||||
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_src,
|
||||
make_tuple(n0, n1, n2, k0, k1, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
const auto src_window = make_tile_window(src_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
|
||||
get_src_dist());
|
||||
|
||||
auto dst_view = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, k0, n1, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(n0, n1, k0, k1, n2, k2),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
// constexpr index_t waveflatten = kw*nw*kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
|
||||
number<Alignment>{}); // control vector load
|
||||
auto tmp_1 = transform_tensor_view(
|
||||
tmp,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(nr, number<nw>{})),
|
||||
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return tmp_1;
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten = kw * nw * kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, waveflatten),
|
||||
number<Alignment>{}); // control vector load
|
||||
return tmp;
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
auto dst_window = [&]() {
|
||||
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
|
||||
{
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<n0_tile>{},
|
||||
number<n1_tile>{},
|
||||
number<k0_tile>{},
|
||||
number<k1_tile>{},
|
||||
number<n2_tile>{},
|
||||
number<k2_tile>{}),
|
||||
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
|
||||
get_dst_dist());
|
||||
}
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{i_n * NPerBlock, i_k * KPerBlock},
|
||||
get_dst_dist());
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten_tile = kw * nw * kv;
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<nr_tile>{},
|
||||
number<kr_tile>{},
|
||||
number<waveflatten_tile>{}),
|
||||
{i_n * nr_tile, i_k * kr_tile, 0},
|
||||
get_dst_dist());
|
||||
#endif
|
||||
}
|
||||
}();
|
||||
|
||||
// actual load store
|
||||
auto src_tile = load_tile(src_window);
|
||||
|
||||
// now we only swap the distribution from src to dst, no extra movement occurs
|
||||
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
|
||||
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
|
||||
|
||||
// final store
|
||||
store_tile(dst_window, dst_tile);
|
||||
}
|
||||
};
|
||||
};
|
||||
411
example/ck_tile/06_permute/permute.cpp
Normal file
411
example/ck_tile/06_permute/permute.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "permute.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
#include "alternative_impl/matrix_core_swizzle.hpp"
|
||||
#endif
|
||||
|
||||
namespace detail {
|
||||
template <int bytes>
|
||||
struct to_integer_type;
|
||||
|
||||
template <>
|
||||
struct to_integer_type<4>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<2>
|
||||
{
|
||||
using type = int16_t;
|
||||
};
|
||||
template <>
|
||||
struct to_integer_type<1>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int bytes>
|
||||
using to_integer_type = typename detail::to_integer_type<bytes>::type;
|
||||
|
||||
// host API (shoule come from codegen)
|
||||
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp8") == 0)
|
||||
{
|
||||
using DataType = ck_tile::fp8_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
using DataType = ck_tile::half_t;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
else if(t.data_type.compare("fp32") == 0)
|
||||
{
|
||||
using DataType = float;
|
||||
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
|
||||
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("shape", "2,3,4", "the shape of the input tensor")
|
||||
.insert("perm", "2,1,0", "permute perm")
|
||||
.insert("kname", "0", "t to 1 will print kernel name")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
// "1,2,3,4" -> vector{1,2,3,4}
|
||||
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
std::string::size_type pos = 0;
|
||||
std::vector<ck_tile::index_t> v;
|
||||
while(true)
|
||||
{
|
||||
auto found = q_val.find(',', pos);
|
||||
ck_tile::index_t n =
|
||||
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
|
||||
v.push_back(n);
|
||||
if(found == std::string::npos)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos = found + 1;
|
||||
}
|
||||
return v;
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
|
||||
auto shape = decode_vec(arg_parser.get_str("shape"));
|
||||
auto perm = decode_vec(arg_parser.get_str("perm"));
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
int seed = arg_parser.get_int("seed");
|
||||
|
||||
assert(shape.size() == perm.size());
|
||||
ck_tile::index_t rank = perm.size();
|
||||
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
|
||||
{
|
||||
printf("rank %d permute is not support yet\n", rank);
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> x(shape);
|
||||
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
|
||||
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
// std::cout << "@@@@" << tmp << std::endl;
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
|
||||
// static_cast<int>(rank)
|
||||
// << std::endl;
|
||||
tmp[i] = shape[perm[i]];
|
||||
}
|
||||
// std::cout << "@@@" << tmp << std::endl;
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<DataType> y(y_shape);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x.data());
|
||||
|
||||
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
|
||||
<< std::flush;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
stream_warmup,
|
||||
stream_repeat};
|
||||
float ave_time = 0.f;
|
||||
auto run_permute = [&]() {
|
||||
permute_traits t;
|
||||
t.data_type = data_type;
|
||||
|
||||
permute_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.rank = rank;
|
||||
std::copy(shape.begin(), shape.end(), a.shape);
|
||||
std::copy(perm.begin(), perm.end(), a.perm);
|
||||
|
||||
return permute(t, a, stream_config);
|
||||
};
|
||||
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
|
||||
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
|
||||
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
|
||||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
|
||||
{
|
||||
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
|
||||
{
|
||||
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
|
||||
auto nr = shape[1];
|
||||
auto nw = shape[2];
|
||||
auto kr = shape[3];
|
||||
auto kw = shape[4];
|
||||
auto kv = shape[5];
|
||||
a.n = nr * nw;
|
||||
a.k = kr * kw * kv;
|
||||
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
|
||||
{
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
matrix_core_swizzle_args a;
|
||||
a.p_src = x_buf.GetDeviceBuffer();
|
||||
a.p_dst = y_buf.GetDeviceBuffer();
|
||||
a.batch = shape[0];
|
||||
a.n = shape[1] * shape[2] * shape[3];
|
||||
a.k = shape[4] * shape[5] * shape[6];
|
||||
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
|
||||
shape[4] % 8 == 0 && shape[1] % 2 == 0)
|
||||
{
|
||||
// 32x32x8 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
|
||||
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
|
||||
|
||||
t.inst = "32x32x8";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
|
||||
shape[4] % 4 == 0 && shape[1] % 4 == 0)
|
||||
{
|
||||
// 16x16x16 inst
|
||||
// perm=0,1,4,2,5,3,6
|
||||
// y_shape=*,4x,4x,4,4,16,8
|
||||
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
|
||||
t.inst = "16x16x16";
|
||||
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
|
||||
|
||||
ave_time = matrix_core_swizzle(t, a, stream_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
ave_time = run_permute();
|
||||
}
|
||||
std::cout << ", time:" << ave_time << "ms" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
reference_permute(x, y, perm);
|
||||
#if 0
|
||||
if constexpr (std::is_same_v<float, DataType>){
|
||||
// using itype = to_integer_type<sizeof(DataType)>;
|
||||
fflush(stdout);
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", x.mData[zz]);
|
||||
}
|
||||
printf("->\n");
|
||||
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
|
||||
printf("%3.0f ", y.mData[zz]);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
#endif
|
||||
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
|
||||
|
||||
y_buf.FromDevice(y_dev.data());
|
||||
|
||||
pass = std::equal(
|
||||
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
|
||||
using itype = to_integer_type<sizeof(DataType)>;
|
||||
itype i_d = ck_tile::bit_cast<itype>(d);
|
||||
itype i_h = ck_tile::bit_cast<itype>(h);
|
||||
return i_d == i_h;
|
||||
});
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp32")
|
||||
{
|
||||
return run<float>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
19
example/ck_tile/06_permute/permute.hpp
Normal file
19
example/ck_tile/06_permute/permute.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/permute.hpp"
|
||||
#include <string>
|
||||
|
||||
struct permute_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
using permute_args = ck_tile::GenericPermuteHostArgs;
|
||||
|
||||
// host API
|
||||
float permute(permute_traits, permute_args, const ck_tile::stream_config&);
|
||||
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
34
example/ck_tile/06_permute/script/smoke_test.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root
|
||||
BUILD=build
|
||||
EXE=$BUILD/bin/tile_example_permute
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
if [ $# -ge 1 ] ; then
|
||||
set -x
|
||||
fi
|
||||
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
|
||||
|
||||
echo "------------------------------------------------------------------"
|
||||
|
||||
for prec in "fp8" "fp16" "fp32" ; do
|
||||
|
||||
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
|
||||
$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
|
||||
echo "------------------------------------------------------------------"
|
||||
done
|
||||
Reference in New Issue
Block a user