[CK_TILE] add generic_permute (#1607)

This commit is contained in:
valarLip
2024-10-29 18:05:53 +08:00
committed by GitHub
parent 922e42a039
commit 9fbd72e97e
14 changed files with 1318 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp)
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
# set(PERMUTE_USE_ALTERNATIVE_IMPL false)
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
endif()
if(PERMUTE_USE_ALTERNATIVE_IMPL)
target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp)
endif()
# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker)

View File

@@ -0,0 +1,46 @@
# permute
This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example.
```
args:
-v weather do CPU validation or not (default:1)
-prec data type. fp16/bf16/fp32 (default:fp16)
-shape the shape of the input tensor (default:2,3,4)
-perm permute perm (default:2,1,0)
```
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_permute -j
```
This will result in an executable `build/bin/tile_example_permute`
## some examples
```
# torch
x=torch.randn(2,3,4,6)
y=x.permute(0,3,2,1).contiguous()
# ck_tile
./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1
```
or you can try the smoke_test
```
# in the root of ck_tile, after you build this example
sh example/ck_tile/06_permute/script/smoke_test.sh
```
### alternative implementation
we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail.
```
# example
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
```

View File

@@ -0,0 +1,98 @@
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "matrix_core_swizzle_kernel.hpp"
#include <string>
struct matrix_core_swizzle_traits
{
std::string data_type; // fp16 only
std::string inst; // 32x32x8, 16x16x16
std::string permute; //
};
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);

View File

@@ -0,0 +1,413 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
// if set to 1, slightly more instructions generated to calculate address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
enum class matrix_core_inst_enum
{
MFMA_32x32x8_F16 = 0,
MFMA_16x16x16_F16 = 1,
};
namespace detail {
template <matrix_core_inst_enum>
struct to_warp_gemm;
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
};
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
};
} // namespace detail
template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
// TODO: in below permute pattern, the last 3 dim is within wave
enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
//
// 4(waves) 32(mfma_m lane)
// | |
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
// nr kr |
// nr 4 32 kr 2 8 2(klane)
//
// permute: 0,1,4,2,5,3,6
// or
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
// permute: 0,1,2,4,5,3,6
//
// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling
// for simplicity, only consider n/k is multiple of block-size
// independend host arg with no template
struct matrix_core_swizzle_host_args
{
const void* p_src;
void* p_dst;
int32_t batch;
int32_t n;
int32_t k;
};
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
template <int BLOCK_SIZE_ = 256,
int NPerBlock_ = 256,
int KPerBlock_ = 128,
matrix_core_permute_style pstyle_ =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
struct matrix_core_swizzle_kernel
{
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;
static constexpr matrix_core_inst_enum Inst = Inst_;
static constexpr ck_tile::index_t Alignment = 8;
karg a;
dim3 grids;
using WarpGemm = to_warp_gemm_t<Inst>;
__host__ matrix_core_swizzle_kernel(harg h)
{
a = h;
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
grids = dim3(ks, ns, h.batch);
}
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<5, 3>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 4, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
__device__ static constexpr auto get_dst_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<3>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 2, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 3, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
#if MERGE_2D_013425
tile_distribution_encoding<
sequence<1>,// 0 R
// major 1 2
// minor 0 1 2 0 1 2 3
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
// Nr_y Kr_y Kv
sequence<1, 2, 2>, // Y major
sequence<0, 0, 3>>{}); // y minor
#else
tile_distribution_encoding<
sequence<1>,// 0 R
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
// Nr_y Kr_y Kv
sequence<1, 2, 3>, // Y major
sequence<0, 0, 2>>{}); // y minor
#endif
// clang-format on
}
}
__device__ void operator()(karg a_)
{
using namespace ck_tile;
index_t i_k = blockIdx.x;
index_t i_n = blockIdx.y;
index_t i_b = blockIdx.z;
constexpr index_t k2 = Alignment;
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
const index_t k0 = a_.k / (k1 * k2);
const index_t n0 = a_.n / (n1 * n2);
constexpr index_t k2_tile = Alignment;
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
const auto src_view = [&]() {
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_src,
make_tuple(n0, n1, n2, k0, k1, k2),
number<Alignment>{}); // control vector load
return tmp;
}();
const auto src_window = make_tile_window(src_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<n2_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
get_src_dist());
auto dst_view = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, k0, n1, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, n1, k0, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
else
{
#if MERGE_2D_013425
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
// constexpr index_t waveflatten = kw*nw*kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
number<Alignment>{}); // control vector load
auto tmp_1 = transform_tensor_view(
tmp,
make_tuple(
make_merge_transform(make_tuple(nr, number<nw>{})),
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten = kw * nw * kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, waveflatten),
number<Alignment>{}); // control vector load
return tmp;
#endif
}
}();
auto dst_window = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<k0_tile>{},
number<n1_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
get_dst_dist());
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
get_dst_dist());
}
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten_tile = kw * nw * kv;
constexpr index_t nr_tile = NPerBlock / nw;
constexpr index_t kr_tile = KPerBlock / (kw * kv);
return make_tile_window(dst_view,
make_tuple(number<nr_tile>{},
number<kr_tile>{},
number<waveflatten_tile>{}),
{i_n * nr_tile, i_k * kr_tile, 0},
get_dst_dist());
#endif
}
}();
// actual load store
auto src_tile = load_tile(src_window);
// now we only swap the distribution from src to dst, no extra movement occurs
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
// final store
store_tile(dst_window, dst_tile);
}
};
};

View File

@@ -0,0 +1,411 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
namespace detail {
template <int bytes>
struct to_integer_type;
template <>
struct to_integer_type<4>
{
using type = int32_t;
};
template <>
struct to_integer_type<2>
{
using type = int16_t;
};
template <>
struct to_integer_type<1>
{
using type = int8_t;
};
} // namespace detail
template <int bytes>
using to_integer_type = typename detail::to_integer_type<bytes>::type;
// host API (shoule come from codegen)
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp8") == 0)
{
using DataType = ck_tile::fp8_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp16") == 0)
{
using DataType = ck_tile::half_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp32") == 0)
{
using DataType = float;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
.insert("shape", "2,3,4", "the shape of the input tensor")
.insert("perm", "2,1,0", "permute perm")
.insert("kname", "0", "t to 1 will print kernel name")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto shape = decode_vec(arg_parser.get_str("shape"));
auto perm = decode_vec(arg_parser.get_str("perm"));
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
int seed = arg_parser.get_int("seed");
assert(shape.size() == perm.size());
ck_tile::index_t rank = perm.size();
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
{
printf("rank %d permute is not support yet\n", rank);
return false;
}
ck_tile::HostTensor<DataType> x(shape);
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
// std::cout << "@@@@" << tmp << std::endl;
for(int i = 0; i < static_cast<int>(rank); i++)
{
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
// static_cast<int>(rank)
// << std::endl;
tmp[i] = shape[perm[i]];
}
// std::cout << "@@@" << tmp << std::endl;
return tmp;
}();
ck_tile::HostTensor<DataType> y(y_shape);
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
x_buf.ToDevice(x.data());
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
<< std::flush;
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
stream_warmup,
stream_repeat};
float ave_time = 0.f;
auto run_permute = [&]() {
permute_traits t;
t.data_type = data_type;
permute_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.rank = rank;
std::copy(shape.begin(), shape.end(), a.shape);
std::copy(perm.begin(), perm.end(), a.perm);
return permute(t, a, stream_config);
};
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
auto nr = shape[1];
auto nw = shape[2];
auto kr = shape[3];
auto kw = shape[4];
auto kv = shape[5];
a.n = nr * nw;
a.k = kr * kw * kv;
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
else
{
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
a.n = shape[1] * shape[2] * shape[3];
a.k = shape[4] * shape[5] * shape[6];
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
shape[4] % 8 == 0 && shape[1] % 2 == 0)
{
// 32x32x8 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
shape[4] % 4 == 0 && shape[1] % 4 == 0)
{
// 16x16x16 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,4x,4x,4,4,16,8
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
}
else
#endif
{
ave_time = run_permute();
}
std::cout << ", time:" << ave_time << "ms" << std::flush;
bool pass = true;
if(do_validation)
{
reference_permute(x, y, perm);
#if 0
if constexpr (std::is_same_v<float, DataType>){
// using itype = to_integer_type<sizeof(DataType)>;
fflush(stdout);
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
printf("%3.0f ", x.mData[zz]);
}
printf("->\n");
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
printf("%3.0f ", y.mData[zz]);
}
fflush(stdout);
}
#endif
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
y_buf.FromDevice(y_dev.data());
pass = std::equal(
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
using itype = to_integer_type<sizeof(DataType)>;
itype i_d = ck_tile::bit_cast<itype>(d);
itype i_h = ck_tile::bit_cast<itype>(h);
return i_d == i_h;
});
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::endl;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp8")
{
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp32")
{
return run<float>(arg_parser) ? 0 : -2;
}
return -3;
}

View File

@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/permute.hpp"
#include <string>
struct permute_traits
{
std::string data_type;
};
using permute_args = ck_tile::GenericPermuteHostArgs;
// host API
float permute(permute_traits, permute_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,34 @@
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_permute
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
if [ $# -ge 1 ] ; then
set -x
fi
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
echo "------------------------------------------------------------------"
for prec in "fp8" "fp16" "fp32" ; do
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
echo "------------------------------------------------------------------"
done