From a712223d4d873e5a03cdcf1790c4cb102f35e505 Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:05:53 +0800 Subject: [PATCH] [CK_TILE] add generic_permute (#1607) [ROCm/composable_kernel commit: 9fbd72e97e34f530ae370527755b655bf390d9ee] --- example/ck_tile/06_permute/CMakeLists.txt | 13 + example/ck_tile/06_permute/README.md | 46 ++ .../alternative_impl/matrix_core_swizzle.cpp | 98 +++++ .../alternative_impl/matrix_core_swizzle.hpp | 20 + .../matrix_core_swizzle_kernel.hpp | 413 ++++++++++++++++++ example/ck_tile/06_permute/permute.cpp | 411 +++++++++++++++++ example/ck_tile/06_permute/permute.hpp | 19 + .../ck_tile/06_permute/script/smoke_test.sh | 34 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/host.hpp | 1 + .../host/reference/reference_permute.hpp | 57 +++ include/ck_tile/ops/permute.hpp | 8 + .../permute/kernel/generic_permute_kernel.hpp | 169 +++++++ .../pipeline/generic_petmute_problem.hpp | 28 ++ 14 files changed, 1318 insertions(+) create mode 100644 example/ck_tile/06_permute/CMakeLists.txt create mode 100644 example/ck_tile/06_permute/README.md create mode 100644 example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp create mode 100644 example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp create mode 100644 example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp create mode 100644 example/ck_tile/06_permute/permute.cpp create mode 100644 example/ck_tile/06_permute/permute.hpp create mode 100644 example/ck_tile/06_permute/script/smoke_test.sh create mode 100644 include/ck_tile/host/reference/reference_permute.hpp create mode 100644 include/ck_tile/ops/permute.hpp create mode 100644 include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp create mode 100644 include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp diff --git a/example/ck_tile/06_permute/CMakeLists.txt b/example/ck_tile/06_permute/CMakeLists.txt new file mode 
100644 index 0000000000..327fceb685 --- /dev/null +++ b/example/ck_tile/06_permute/CMakeLists.txt @@ -0,0 +1,13 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp) + +if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) +# set(PERMUTE_USE_ALTERNATIVE_IMPL false) +set(PERMUTE_USE_ALTERNATIVE_IMPL true) +endif() +if(PERMUTE_USE_ALTERNATIVE_IMPL) +target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL) +target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp) +endif() +# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker) diff --git a/example/ck_tile/06_permute/README.md b/example/ck_tile/06_permute/README.md new file mode 100644 index 0000000000..03bd810ff4 --- /dev/null +++ b/example/ck_tile/06_permute/README.md @@ -0,0 +1,46 @@ +# permute + +This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example. + + +``` +args: + -v weather do CPU validation or not (default:1) + -prec data type. fp16/bf16/fp32 (default:fp16) + -shape the shape of the input tensor (default:2,3,4) + -perm permute perm (default:2,1,0) +``` + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
+make tile_example_permute -j +``` +This will result in an executable `build/bin/tile_example_permute` + + +## some examples +``` +# torch +x=torch.randn(2,3,4,6) +y=x.permute(0,3,2,1).contiguous() + +# ck_tile +./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1 +``` + +or you can try the smoke_test +``` +# in the root of ck_tile, after you build this example +sh example/ck_tile/06_permute/script/smoke_test.sh +``` + +### alternative implementation +we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail. +``` +# example +./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2 +./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2 +``` diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp new file mode 100644 index 0000000000..93c662a288 --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp @@ -0,0 +1,98 @@ +#include "matrix_core_swizzle.hpp" +#include "matrix_core_swizzle_kernel.hpp" + +float matrix_core_swizzle(matrix_core_swizzle_traits t, + matrix_core_swizzle_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.inst.compare("32x32x8") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + 
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + else if(t.inst.compare("16x16x16") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + } + return -1; +} diff --git 
a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp new file mode 100644 index 0000000000..e1ecdbbe64 --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "matrix_core_swizzle_kernel.hpp" +#include + +struct matrix_core_swizzle_traits +{ + std::string data_type; // fp16 only + std::string inst; // 32x32x8, 16x16x16 + std::string permute; // +}; + +using matrix_core_swizzle_args = matrix_core_swizzle_host_args; + +// host API +float matrix_core_swizzle(matrix_core_swizzle_traits, + matrix_core_swizzle_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp new file mode 100644 index 0000000000..60ac103ec3 --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +// if set to 1, slightly more instructions generated to calculate address +#ifndef MERGE_2D_013425 +#define MERGE_2D_013425 0 +#endif + +enum class matrix_core_inst_enum +{ + MFMA_32x32x8_F16 = 0, + MFMA_16x16x16_F16 = 1, +}; + +namespace detail { +template +struct to_warp_gemm; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8; +}; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16; +}; +} // namespace detail +template +using to_warp_gemm_t = typename detail::to_warp_gemm::type; + +// TODO: in below permute pattern, the last 3 dim is within wave +enum class matrix_core_permute_style +{ + permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 + permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 + permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, +}; + +// assume this is B matrix, originally we have batch*n*k +// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 +// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4) +// +// 4(waves) 32(mfma_m lane) +// | | +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading) +// nr kr | +// nr 4 32 kr 2 8 2(klane) +// +// permute: 0,1,4,2,5,3,6 +// or +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading) +// permute: 0,1,2,4,5,3,6 +// +// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling +// for simplicity, only consider n/k is multiple of block-size + +// independend host arg with no template +struct matrix_core_swizzle_host_args +{ + const void* p_src; + void* p_dst; + int32_t batch; + int32_t n; + int32_t k; +}; + +// NOTE: this kernel could follow the style of generic permute kernel +// but here we pass in fixed layout as template arg and generate different kernel 
instance +// purposely +template +struct matrix_core_swizzle_kernel +{ + using karg = matrix_core_swizzle_host_args; + using harg = matrix_core_swizzle_host_args; + + static constexpr int BLOCK_SIZE = BLOCK_SIZE_; + static constexpr int WavesPerBlock_N = 4; + static constexpr int WavesPerBlock_K = 1; + static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE); + static constexpr int NPerBlock = NPerBlock_; + static constexpr int KPerBlock = KPerBlock_; + static constexpr matrix_core_permute_style pstyle = pstyle_; + static constexpr matrix_core_inst_enum Inst = Inst_; + + static constexpr ck_tile::index_t Alignment = 8; + karg a; + dim3 grids; + + using WarpGemm = to_warp_gemm_t; + + __host__ matrix_core_swizzle_kernel(harg h) + { + a = h; + ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock; + ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock; + grids = dim3(ks, ns, h.batch); + } + + __host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; } + + __host__ void operator()(const ck_tile::stream_config& s) const + { + ck_tile::kentry<<>>(a); + } + + struct kernel + { + __device__ static constexpr auto get_src_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<5, 3>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 4, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + __device__ 
static constexpr auto get_dst_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 2, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 3, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else + { + // clang-format off + // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten + constexpr index_t Kv = Alignment; + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + + static_assert(KPerBlock % (K1 * K2) == 0); + constexpr index_t Nr = NPerBlock / Nw; + constexpr index_t Kr = KPerBlock / (Kv * Kw); + + constexpr index_t Nr_p = WavesPerBlock_N; + constexpr index_t Kr_p = WavesPerBlock_K; + constexpr index_t Nr_y = Nr / Nr_p; + constexpr index_t Kr_y = Kr / Kr_p; + + return make_static_tile_distribution( +#if MERGE_2D_013425 + 
tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 + // minor 0 1 2 0 1 2 3 + tuple, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<2, 1>>, // p major + tuple, sequence<2, 2>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 2>, // Y major + sequence<0, 0, 3>>{}); // y minor +#else + tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 3 + // minor 0 1 0 1 0 1 2 + tuple, sequence, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<3, 3>>, // p major + tuple, sequence<0, 1>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 3>, // Y major + sequence<0, 0, 2>>{}); // y minor +#endif + // clang-format on + } + } + + __device__ void operator()(karg a_) + { + using namespace ck_tile; + index_t i_k = blockIdx.x; + index_t i_n = blockIdx.y; + index_t i_b = blockIdx.z; + + constexpr index_t k2 = Alignment; + constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1 = BLOCK_SIZE / get_warp_size(); + const index_t k0 = a_.k / (k1 * k2); + const index_t n0 = a_.n / (n1 * n2); + + constexpr index_t k2_tile = Alignment; + constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size(); + constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile); + constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile); + + const fp16_t* p_src = reinterpret_cast(a_.p_src) + i_b * a_.k * a_.n; + fp16_t* p_dst = reinterpret_cast(a_.p_dst) + i_b * a_.k * a_.n; + + const auto src_view = [&]() { + const auto tmp = make_naive_tensor_view_packed( + p_src, + make_tuple(n0, n1, n2, k0, k1, k2), + number{}); // control vector load + return tmp; + }(); + + const auto src_window = make_tile_window(src_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, 0, i_k * 
k0_tile, 0, 0}, + get_src_dist()); + + auto dst_view = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, k0, n1, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, n1, k0, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else + { +#if MERGE_2D_013425 + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + // constexpr index_t waveflatten = kw*nw*kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, number{}, number{}, number{}), + number{}); // control vector load + auto tmp_1 = transform_tensor_view( + tmp, + make_tuple( + make_merge_transform(make_tuple(nr, number{})), + make_merge_transform(make_tuple(kr, number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return tmp_1; +#else + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten = kw * nw * kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, waveflatten), + number{}); // control vector load + return tmp; +#endif + } + }(); + + auto dst_window = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + 
number{}, + number{}, + number{}), + {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0}, + get_dst_dist()); + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0}, + get_dst_dist()); + } + else + { +#if MERGE_2D_013425 + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + return make_tile_window(dst_view, + make_tuple(number{}, number{}), + {i_n * NPerBlock, i_k * KPerBlock}, + get_dst_dist()); +#else + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten_tile = kw * nw * kv; + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}), + {i_n * nr_tile, i_k * kr_tile, 0}, + get_dst_dist()); +#endif + } + }(); + + // actual load store + auto src_tile = load_tile(src_window); + + // now we only swap the distribution from src to dst, no extra movement occurs + auto dst_tile = make_static_distributed_tensor(get_dst_dist()); + dst_tile.get_thread_buffer() = src_tile.get_thread_buffer(); + + // final store + store_tile(dst_window, dst_tile); + } + }; +}; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp new file mode 100644 index 0000000000..af95b64e69 --- /dev/null +++ b/example/ck_tile/06_permute/permute.cpp @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "permute.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL +#include "alternative_impl/matrix_core_swizzle.hpp" +#endif + +namespace detail { +template +struct to_integer_type; + +template <> +struct to_integer_type<4> +{ + using type = int32_t; +}; +template <> +struct to_integer_type<2> +{ + using type = int16_t; +}; +template <> +struct to_integer_type<1> +{ + using type = int8_t; +}; +} // namespace detail + +template +using to_integer_type = typename detail::to_integer_type::type; + +// host API (shoule come from codegen) +float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp8") == 0) + { + using DataType = ck_tile::fp8_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp16") == 0) + { + using DataType = ck_tile::half_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp32") == 0) + { + using DataType = float; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, 
ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)") + .insert("shape", "2,3,4", "the shape of the input tensor") + .insert("perm", "2,1,0", "permute perm") + .insert("kname", "0", "t to 1 will print kernel name") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +// "1,2,3,4" -> vector{1,2,3,4} +std::vector decode_vec(std::string q_val) +{ +#define _S2I_(str_) 
static_cast(std::atoi((str_).c_str())) + std::string::size_type pos = 0; + std::vector v; + while(true) + { + auto found = q_val.find(',', pos); + ck_tile::index_t n = + _S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos)); + v.push_back(n); + if(found == std::string::npos) + { + break; + } + pos = found + 1; + } + return v; +#undef _S2I_ +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + auto shape = decode_vec(arg_parser.get_str("shape")); + auto perm = decode_vec(arg_parser.get_str("perm")); + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + int seed = arg_parser.get_int("seed"); + + assert(shape.size() == perm.size()); + ck_tile::index_t rank = perm.size(); + if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks) + { + printf("rank %d permute is not support yet\n", rank); + return false; + } + + ck_tile::HostTensor x(shape); + ck_tile::FillUniformDistributionIntegerValue{-15, 15, seed}(x); + + std::vector y_shape = [&]() { + std::vector tmp(rank, 0); + // std::cout << "@@@@" << tmp << std::endl; + for(int i = 0; i < static_cast(rank); i++) + { + // std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" << + // static_cast(rank) + // << std::endl; + tmp[i] = shape[perm[i]]; + } + // std::cout << "@@@" << tmp << std::endl; + return tmp; + }(); + + ck_tile::HostTensor y(y_shape); + + ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x.data()); + + std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm + << std::flush; + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + stream_warmup, + stream_repeat}; + float ave_time = 0.f; + auto run_permute = [&]() { + permute_traits t; + t.data_type = data_type; + + permute_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.rank = rank; + std::copy(shape.begin(), shape.end(), a.shape); + std::copy(perm.begin(), perm.end(), a.perm); + + return permute(t, a, stream_config); + }; +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL + // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 + if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))) + { + if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) + { + // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + + auto nr = shape[1]; + auto nw = shape[2]; + auto kr = shape[3]; + auto kw = shape[4]; + auto kv = shape[5]; + a.n = nr * nw; + a.k = kr * kw * kv; + if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + else + { + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + a.n = shape[1] * shape[2] * shape[3]; + a.k = shape[4] * 
shape[5] * shape[6]; + if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && + shape[4] % 8 == 0 && shape[1] % 2 == 0) + { + // 32x32x8 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8) + // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8) + + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 && + shape[4] % 4 == 0 && shape[1] % 4 == 0) + { + // 16x16x16 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,4x,4x,4,4,16,8 + // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8) + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + } + else +#endif + { + ave_time = run_permute(); + } + std::cout << ", time:" << ave_time << "ms" << std::flush; + + bool pass = true; + if(do_validation) + { + reference_permute(x, y, perm); +#if 0 + if constexpr (std::is_same_v){ + // using itype = to_integer_type; + fflush(stdout); + for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { + printf("%3.0f ", x.mData[zz]); + } + printf("->\n"); + for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { + printf("%3.0f ", y.mData[zz]); + } + fflush(stdout); + } +#endif + ck_tile::HostTensor y_dev(y.get_lengths()); + + y_buf.FromDevice(y_dev.data()); + + pass = std::equal( + y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) { + using itype = to_integer_type; + itype i_d = ck_tile::bit_cast(d); + itype i_h = ck_tile::bit_cast(h); + return i_d == i_h; + }); + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush; + } + + std::cout << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp32") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/06_permute/permute.hpp b/example/ck_tile/06_permute/permute.hpp new file mode 100644 index 0000000000..304da4dc97 --- /dev/null +++ b/example/ck_tile/06_permute/permute.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/permute.hpp" +#include + +struct permute_traits +{ + std::string data_type; +}; + +using permute_args = ck_tile::GenericPermuteHostArgs; + +// host API +float permute(permute_traits, permute_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/06_permute/script/smoke_test.sh b/example/ck_tile/06_permute/script/smoke_test.sh new file mode 100644 index 0000000000..793e52d2bb --- /dev/null +++ b/example/ck_tile/06_permute/script/smoke_test.sh @@ -0,0 +1,34 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_permute +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 +if [ $# -ge 1 ] ; then + set -x +fi + +$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS 
# Remaining fp16 matrix-core-swizzle style permutations (continuation of the
# list started above). run_case wraps the common invocation pattern.
run_case() {
    # $1 = precision, $2 = input shape, $3 = permutation
    $EXE -prec="$1" -shape="$2" -perm="$3" $COMMON_ARGS
}

run_case fp16 3,8,4,16,16,4,8 0,1,2,4,5,3,6
run_case fp16 2,8,16,8,4,8    0,1,3,4,2,5
run_case fp16 1,24,32,16,2,8  0,1,3,4,2,5

echo "------------------------------------------------------------------"

# Generic rank-2..8 permutations, repeated for every supported precision.
for prec in fp8 fp16 fp32 ; do

run_case $prec 3,8              1,0
run_case $prec 48,6,8           2,1,0
run_case $prec 24,128,3         0,2,1
run_case $prec 4,10,7,6         0,2,3,1
run_case $prec 8,24,36,10       3,1,2,0
run_case $prec 8,1,36,4         2,1,0,3
run_case $prec 5,10,16,2,36,4   4,5,2,1,0,3
run_case $prec 2,32,8,3,6,2,5,4 5,2,4,7,1,6,3,0
echo "------------------------------------------------------------------"
done
mode 100644 index 0000000000..1c82483407 --- /dev/null +++ b/include/ck_tile/host/reference/reference_permute.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include +#include +#include + +namespace ck_tile { + +/* + this will do permute + contiguous like functionality in pytorch +*/ +template +CK_TILE_HOST void +reference_permute(const HostTensor& x, HostTensor& y, std::vector dims) +{ + const auto x_len = x.mDesc.get_lengths(); + const auto y_len = y.mDesc.get_lengths(); + assert(x_len.size() == y_len.size()); + index_t rank = x_len.size(); + const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies()); + const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies()); + assert(x_elm == y_elm); + (void)y_elm; + + auto f = [&](auto i_element) { + std::vector y_coord = [&]() { + std::vector tmp(rank, 0); + size_t r = i_element; + for(index_t i = rank - 1; i >= 0; i--) + { + tmp[i] = r % y_len[i]; + r = r / y_len[i]; + } + return tmp; + }(); + + std::vector x_coord = [&]() { + std::vector tmp(rank, 0); + for(index_t i = 0; i < rank; i++) + { + tmp[dims[i]] = y_coord[i]; + } + return tmp; + }(); + + // do permute + y(y_coord) = x(x_coord); + }; + + make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp new file mode 100644 index 0000000000..ee8c693727 --- /dev/null +++ b/include/ck_tile/ops/permute.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" +#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp new file mode 100644 index 0000000000..1c5cc4a11a --- /dev/null +++ b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" + +namespace ck_tile { + +/* independent host side argument, no template + */ +struct GenericPermuteHostArgs +{ + static constexpr index_t kMaxRanks = 8; // TODO: hardcoded + + const void* p_src; + void* p_dst; + index_t rank; + index_t shape[kMaxRanks]; // input shape + index_t perm[kMaxRanks]; // permute index +}; + +/* +simulate torch.permute: +x_ = x_.view(x.shape[0], + x.shape[1]//16, 16, + x.shape[2]//32, 4, 8) +x_ = x_.permute(0,1,3,4,2,5) +x_ = x_.contiguous() +x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);// + +this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks +dim of permutation, with a single kernel + +*/ +template +struct GenericPermute +{ + using Problem = ck_tile::remove_cvref_t; + + using DataType = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMaxRanks = Problem::kMaxRanks; + static constexpr bool KeepLastDim = Problem::KeepLastDim; + + struct __attribute__((packed)) Kargs + { + const void* p_src; + void* p_dst; + // index_t rank; + index_t num_elements; + index_t perm_length[kMaxRanks]; // tensor length after permutation + index_t perm_stride[kMaxRanks]; // tensor stride after permutation + }; + + 
CK_TILE_HOST static constexpr index_t TotalElements(const GenericPermuteHostArgs& h) + { + index_t n = 1; + for(auto i = 0; i < h.rank; i++) + { + n *= h.shape[i]; + } + return n; + } + + CK_TILE_HOST static constexpr Kargs MakeKargs(const GenericPermuteHostArgs& h) + { + Kargs a; + a.p_src = h.p_src; + a.p_dst = h.p_dst; + + // assert rank <= kMaxRanks + index_t i = 0; + + index_t perm[kMaxRanks]; + index_t x_shape[kMaxRanks]; + index_t x_stride[kMaxRanks]; + // index_t perm_length[kMaxRanks]; + + for(; i < h.rank; i++) + { + x_shape[i] = h.shape[i]; + perm[i] = h.perm[i]; + } + for(; i < kMaxRanks; i++) + { + x_shape[i] = 1; + perm[i] = i; // will index to len = 1 + } + + index_t stride = 1; + for(index_t j = kMaxRanks - 1; j >= 0; j--) + { + x_stride[j] = stride; + stride *= x_shape[j]; + } + + for(index_t j = 0; j < kMaxRanks; j++) + { + a.perm_length[j] = x_shape[perm[j]]; + a.perm_stride[j] = x_stride[perm[j]]; + } + + a.num_elements = TotalElements(h); + return a; + } + + CK_TILE_HOST static constexpr auto GridSize(GenericPermuteHostArgs h) + { + auto total = TotalElements(h); + auto grids = dim3((total + BlockSize() - 1) / BlockSize()); + // printf("### total:%d, grids:%dx%dx%d\n", total, ); + return grids; + } + + CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + index_t id = blockIdx.x * BlockSize() + threadIdx.x; + + if(id >= kargs.num_elements) + return; + + const auto perm_length = + generate_tuple([&](auto I) { return kargs.perm_length[I]; }, number{}); + const auto perm_stride = + generate_tuple([&](auto I) { return kargs.perm_stride[I]; }, number{}); + + const DataType* p_src = reinterpret_cast(kargs.p_src); + DataType* p_dst = reinterpret_cast(kargs.p_dst); + + const auto src_view_0 = make_naive_tensor_view( + p_src, perm_length, perm_stride, number<1>{}, number<1>{}); + + const auto src_view = transform_tensor_view( + src_view_0, + 
make_tuple(make_merge_transform(perm_length)), + make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}), + make_tuple(sequence<0>{})); + + auto dst_view_0 = make_naive_tensor_view_packed( + p_dst, perm_length, number<1>{}); + + auto dst_view = transform_tensor_view( + dst_view_0, + make_tuple(make_merge_transform(perm_length)), + make_tuple(typename arithmetic_sequence_gen<0, kMaxRanks, 1>::type{}), + make_tuple(sequence<0>{})); + + // TODO: hard code to vector 1 + using vector_t = thread_buffer; + + const auto src_coord = + make_tensor_coordinate(src_view.get_tensor_descriptor(), array{id}); + const auto dst_coord = + make_tensor_coordinate(dst_view.get_tensor_descriptor(), array{id}); + + // printf("src id:%d, os:%d\n", id, src_coord.get_offset()); + // printf("dst id:%d, os:%d\n", id, dst_coord.get_offset()); + + const vector_t x = src_view.template get_vectorized_elements(src_coord, 0); + dst_view.template set_vectorized_elements(dst_coord, 0, x); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp new file mode 100644 index 0000000000..e504ed7472 --- /dev/null +++ b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct GenericPermuteProblem +{ + using DataType = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMaxRanks = kMaxRanks_; + /* KeepLastDim: + * if last dim keep the same? this can help enable vector load + * permute(0, 2, 4, 1, 3, 5) -> true + * permute(0, 3, 2, 1) -> false + */ + static constexpr bool KeepLastDim = KeepLastDim_; + // TODO: not used(?) +}; + +} // namespace ck_tile