From d5fcf10b2905c6bfdb70dfad442bb60d3256cefc Mon Sep 17 00:00:00 2001 From: Emily Martins <65371150+ecamartins@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:20:28 -0600 Subject: [PATCH] Tests for CK tile Permute and MOE Sorting (#2417) * Convert ck-tile 06_permute smoke test to unit tests for fp16, fp8, and fp32 * Apply clang format and update copyright year * Convert ck tile moe sorting example smoke tests to unit tests * fix CMakeLists to ensure that permute and moe_sorting are built for gfx9 only. * Remove number prefix from permute and moe_sorting directory names * code cleanup * add missing test cases for fp16 permute * remove unnecessary parentheses * Cleanup * Remove unnecessary final nullptr * update copyright and licensing statement in files * Add custom target for permute tests * Add missing new line at end of file for moe sorting CMakeLists. * Update MOE sorting tests to account for MOE sorting example updates The ck_tile/13_moe_sorting example was updated to include different cases depending on whether MOE_SORTING_FMOE_2D_BUF is set. So, the ck_tile tests for MOE sorting were updated to account for these changes. 
--------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> [ROCm/composable_kernel commit: 1fa1c34b7e70939ed1e131edef0e6d7ae6b29d0d] --- test/ck_tile/CMakeLists.txt | 2 + test/ck_tile/moe_sorting/CMakeLists.txt | 15 + test/ck_tile/moe_sorting/moe_sorting_api.cpp | 444 +++++++++++++++ test/ck_tile/moe_sorting/moe_sorting_api.hpp | 33 ++ test/ck_tile/moe_sorting/moe_sorting_fp32.cpp | 538 ++++++++++++++++++ test/ck_tile/permute/CMakeLists.txt | 33 ++ .../alternative_impl/matrix_core_swizzle.cpp | 101 ++++ .../alternative_impl/matrix_core_swizzle.hpp | 20 + .../matrix_core_swizzle_kernel.hpp | 413 ++++++++++++++ test/ck_tile/permute/permute.hpp | 19 + test/ck_tile/permute/permute_fp16.cpp | 29 + test/ck_tile/permute/permute_fp32.cpp | 29 + test/ck_tile/permute/permute_fp8.cpp | 29 + test/ck_tile/permute/permute_utils.inc | 490 ++++++++++++++++ 14 files changed, 2195 insertions(+) create mode 100644 test/ck_tile/moe_sorting/CMakeLists.txt create mode 100644 test/ck_tile/moe_sorting/moe_sorting_api.cpp create mode 100644 test/ck_tile/moe_sorting/moe_sorting_api.hpp create mode 100644 test/ck_tile/moe_sorting/moe_sorting_fp32.cpp create mode 100644 test/ck_tile/permute/CMakeLists.txt create mode 100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle.cpp create mode 100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp create mode 100644 test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp create mode 100644 test/ck_tile/permute/permute.hpp create mode 100644 test/ck_tile/permute/permute_fp16.cpp create mode 100644 test/ck_tile/permute/permute_fp32.cpp create mode 100644 test/ck_tile/permute/permute_fp8.cpp create mode 100644 test/ck_tile/permute/permute_utils.inc diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 0b6fd35988..648fdc7739 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -8,6 +8,8 @@ add_subdirectory(data_type) # Not including these 
tests as there is a bug on gfx90a and gfx942 # resulting in "GPU core dump" #add_subdirectory(moe_smoothquant) +add_subdirectory(permute) +add_subdirectory(moe_sorting) add_subdirectory(slice_tile) add_subdirectory(batched_transpose) add_subdirectory(smoothquant) diff --git a/test/ck_tile/moe_sorting/CMakeLists.txt b/test/ck_tile/moe_sorting/CMakeLists.txt new file mode 100644 index 0000000000..e360293878 --- /dev/null +++ b/test/ck_tile/moe_sorting/CMakeLists.txt @@ -0,0 +1,15 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + + add_test_executable(test_ck_tile_moe_sorting_fp32 moe_sorting_fp32.cpp moe_sorting_api.cpp) + target_include_directories(test_ck_tile_moe_sorting_fp32 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + + set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + # list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + target_compile_options(test_ck_tile_moe_sorting_fp32 PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS}) + +else() + message(DEBUG "Skipping ck_tile_moe_sorting tests for current target") +endif() diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp new file mode 100644 index 0000000000..0e8998e254 --- /dev/null +++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp @@ -0,0 +1,444 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "moe_sorting_api.hpp" + +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + +#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr ck_tile::index_t expert_tile = expert_tile_; \ + using ms_problem = \ + ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#else + +#define MOE_SORTING_DISPATCH_( \ + sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \ + row_, sub_token_onshot_, local_expert_masking_, local_token_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 4 == 
0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \ + } + +#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(is_local_token) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_DISPATCH(unroll_num_) \ + if(a.num_experts <= 8) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \ + } \ + else if(a.num_experts <= 16) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \ + } \ + else if(a.num_experts <= 32) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \ + } \ + else if(a.num_experts <= 64) \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ + } +#endif + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { +#if !MOE_SORTING_USE_EX_KERNEL + if(a.num_experts > 127) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + 
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(10): { + MOE_SORTING_DISPATCH(10); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } +#else + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0) + { + return moe_sorting_mp(t, a, s); + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts); + auto row_ = sub_token_ / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + bool is_local_token = a.p_local_tokens != nullptr; + + MOE_SORTING_DISPATCH_EMASK_(row_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif + } + return -1; +} + +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = 
expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = 
kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + }() + +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + maybe_clear_workspace, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + maybe_clear_workspace, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + +#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \ + [&]() { \ + using problem_ = \ + ck_tile::MoeSortingClearWorkspaceProblem; \ + using kernel = ck_tile::MoeSortingClearWorkspaceKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, 
grids, blocks, 0, kargs); \ + }() + +float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + bool is_local_token = a.p_local_tokens != nullptr; + if(t.weight_type == "fp32" && t.index_type == "int32") + { + using ms_index_t = ck_tile::index_t; + using ms_weight_type = float; + + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); 
+ return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } + } + } + return -1; +} + +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); +} diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.hpp b/test/ck_tile/moe_sorting/moe_sorting_api.hpp new file mode 100644 index 0000000000..5808d20f6d --- /dev/null +++ b/test/ck_tile/moe_sorting/moe_sorting_api.hpp @@ -0,0 +1,33 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +struct moe_sorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert + bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of + // it inside API) + int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 - + // always use mp kernel NOTE: moe_sorting_get_workspace_size() need use + // same dispatch_policy value. 
it will be undefined behavior if ppl using + // different value when get ws and call the kernel +}; + +struct moe_sorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +// use below API before call moe_sorting() to indicate if need workspace or not +// if return non zero, means need workspace, you need to allocate a GPU buffer +// and set to moe_sorting_args.p_ws +// NOTE: workspace size are required to clear zero before use the API +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy); +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); +float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp new file mode 100644 index 0000000000..cc511984fe --- /dev/null +++ b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp @@ -0,0 +1,538 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "moe_sorting_api.hpp" + +auto create_args(int argc, char* argv[], int index = 0) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).") + .insert("pr_i", "int32", "index data type. Only int32 is currently supported.") + .insert("pr_w", "fp32", "output weight data type. 
Only fp32 is currently supported.") + .insert("t", + "128", + "number of input tokens.\n" + "If \"local_t\" presents, this value indicates global concurrency of all ranks.") + .insert( + "local_t", + "-1", + "Number of local input tokens for curent rank.\n" + "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n" + "This feature is to simulate EP case where where each rank has different tokens.\n" + "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.") + .insert("e", "8", "number of num_experts") + .insert("k", "4", "topk") + .insert("unit", "32", "unit_size") +#if MOE_SORTING_FMOE_2D_BUF + .insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf") + .insert( + "moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...") +#else + .insert("moe_buf_size", "0", "moe_buf_size") +#endif + .insert("ci", + "1", + "clear workspace inside API or not(if \"0\", require manually clear outside)") + .insert( + "dispatch", + "0", + "dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel") + .insert("local_eid", + "-1", + "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n" + "please make sure eid is in ascending order!") + .insert("seed", + "-1", + "seed to be used. 
When set to -1, a random seed will be generated each time " + "invoking this example") + .insert("kname", "0", "prints the kernel name when set to 1") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv, index); + return std::make_tuple(result, arg_parser); +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +template +bool test_moe_sorting(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int local_tokens = args.get_int("local_t"); + int num_experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int unit_size = args.get_int("unit"); +#if MOE_SORTING_FMOE_2D_BUF + int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim"); + int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes"); +#else + int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); +#endif + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + bool clear_inside = args.get_int("ci") != 0; + int dispatch_policy = args.get_int("dispatch"); + + int max_output_ids = + ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > 
num_experts) + { + printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n", + topk, + num_experts); + return false; + } + + // if local_tokens == tokens, not local_token, but better avoid this since no meaning for such + // case + bool is_local_token = local_tokens >= 0 && local_tokens < tokens; + + if(local_tokens > tokens) + { + printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens); + return false; + } + + bool local_expert_masking = args.get_str("local_eid") != "-1"; + auto local_expert_masking_host = [&]() { + if(local_expert_masking) + { + auto local_eid = args.get_int_vec("local_eid"); + ck_tile::HostTensor v_{{num_experts}}; + v_.SetZero(); + for(auto eid : local_eid) + { + if(eid >= num_experts) + { + throw std::runtime_error( + "local_eid larger than number of expert, please check"); + } + v_.mData[eid] = 1; + } + return v_; + } + else + return ck_tile::HostTensor{{1}}; + }(); + + // tokens already considered batch size + ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); + // for simplicity, below buffer allocate 2 dword + ck_tile::HostTensor sorted_id_cnt_host({2}, {1}); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_host( + {static_cast(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim * + moe_buf_elem_bytes}); + auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#else + ck_tile::HostTensor moe_buf_host({moe_buf_size}); + auto moe_buf_bytes = moe_buf_size == 0 ? 
static_cast(0) + : moe_buf_host.get_element_space_size_in_bytes(); +#endif + + ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#else + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); +#endif + topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); + + ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_dev( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem local_expert_masking_dev( + local_expert_masking_host.get_element_space_size_in_bytes()); + + // used for simulating dynamic_tokens for EP case + ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); + if(is_local_token) + { + local_tokens_dev.ToDevice(&local_tokens); + } + + topk_ids_dev.ToDevice(topk_ids_host.data()); + weights_dev.ToDevice(weights_host.data()); + if(moe_buf_bytes > 0) + { + moe_buf_dev.ToDevice(moe_buf_host.data()); + } + if(local_expert_masking) + local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); + + // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr + ck_tile::index_t workspace_size = + moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy); + ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); + if(workspace_size != 0 && clear_inside == false) + moe_sorting_ws.SetZero(); // note, clear here!!!! 
+ + moe_sorting_trait trait{ + index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; + + moe_sorting_args karg + { + topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, + sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, + num_experts, topk, +#if MOE_SORTING_FMOE_2D_BUF + moe_buf_interm_dim, moe_buf_elem_bytes +#else + static_cast(moe_buf_size * sizeof(float)) +#endif + }; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + warmup, + repeat}; + + auto ms = moe_sorting(trait, karg, sc); + + printf("[%s|%s|%s|%d]tokens:%d", + index_prec.c_str(), + weight_prec.c_str(), + workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"), + dispatch_policy, + tokens); + if(is_local_token) + { + printf("(%d)", local_tokens); + } + printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 
1 : 0); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + + if(moe_buf_bytes > 0) + { +#if MOE_SORTING_FMOE_2D_BUF + printf("moe_buf:%lu(%d,%d), ", + static_cast(moe_buf_bytes), + moe_buf_interm_dim, + moe_buf_elem_bytes); +#else + + printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); +#endif + } + + if(ms < 0) + printf("not supported\n"); + else + printf("ms:%f, ", ms); + fflush(stdout); + if(ms < 0) + { + return false; + } + + sorted_ids_dev.FromDevice(sorted_ids_host.data()); + sorted_weights_dev.FromDevice(sorted_weights_host.data()); + sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); + sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); + if(moe_buf_bytes > 0) + { + moe_buf_dev.FromDevice(moe_buf_host.data()); + } + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor sorted_ids_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_ref({max_output_ids / unit_size}, {1}); + + int32_t ref_total_tokens_post_pad = 0; + ck_tile::reference_moe_sorting(topk_ids_host, + weights_host, + local_expert_masking_host, + sorted_ids_ref, + sorted_weights_ref, + sorted_expert_ids_ref, + ref_total_tokens_post_pad, + num_experts, + unit_size, + is_local_token ? 
local_tokens + : tokens, + local_expert_masking); + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); + if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]) + { + size_t slen = ref_total_tokens_post_pad; + rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}), + sorted_ids_ref.slice({0}, {slen}), + std::string("OUT Error: Incorrect ids!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}), + sorted_weights_ref.slice({0}, {slen}), + std::string("OUT Error: Incorrect w!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}), + sorted_expert_ids_ref.slice({0}, {slen / unit_size}), + std::string("OUT Error: Incorrect eid!"), + 1e-6, + 1e-6); + // if(is_local_token) + { + auto t_ = is_local_token ? local_tokens : tokens; + bool _f = t_ == sorted_id_cnt_host.mData[1]; + rtn &= _f; + if(!_f) + { + printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]); + } + } + } + else + { + printf("(token size not equal!!)"); + rtn = false; + } + + if(moe_buf_bytes) + { +#if MOE_SORTING_FMOE_2D_BUF + ck_tile::HostTensor moe_buf_ref({moe_buf_bytes}); +#else + ck_tile::HostTensor moe_buf_ref({moe_buf_size}); +#endif + rtn &= ck_tile::check_err( + moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); + } + // rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + } + + printf("valid:%s", rtn ? 
"y" : "n"); + fflush(stdout); + if(!rtn) + printf(", (%d)", seed); + printf("\n"); + fflush(stdout); + return rtn; +} +template +bool run_test_case(int argc, char* argv[]) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return false; + + return test_moe_sorting(args); +} + +template +bool run_test_cases(std::vector>& test_cases) +{ + bool valid = true; + + for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx) + { + + constexpr int max_num_args = 7; + const int num_args = test_cases[test_idx].size(); + + assert(max_num_args >= num_args && "Invalid number of arguments in test case"); + + char* argv[max_num_args]; + + for(int arg_idx = 0; arg_idx < num_args; ++arg_idx) + { + argv[arg_idx] = test_cases[test_idx][arg_idx].data(); + } + + try + { + valid = valid && run_test_case(num_args, argv); + + if(!valid) + break; + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return false; + } + } + + return valid; +} + +std::vector> create_test_cases() +{ +#if MOE_SORTING_FMOE_2D_BUF + return {{"-t=80", "-e=17", "-moe_buf_interm_dim=16", "-moe_buf_elem_bytes=4"}, + {"-t=111", "-e=117", "-moe_buf_interm_dim=4", "-moe_buf_elem_bytes=4"}, + {"-t=1000", "-e=55", "-moe_buf_interm_dim=1024", "-moe_buf_elem_bytes=1"}, + {"-t=99", "-e=120", "-moe_buf_interm_dim=10244", "-moe_buf_elem_bytes=2"}, + {"-t=175", "-e=64", "-k=8"}, + {"-t=65", "-e=8", "-k=2"}, + {"-t=1", "-e=25"}, + {"-t=31", "-e=19", "-k=15"}, + {"-t=81", "-e=37", "-k=7"}, + {"-t=23", "-e=1", "-k=1"}, + {"-t=127", "-e=99", "-k=19"}, + {"-t=71", "-e=11", "-k=11"}, + {"-t=1", "-e=1", "-k=1"}, + {"-t=99", "-e=2", "-k=1"}, + {"-t=333", "-e=99", "-k=13"}, + {"-t=11", "-e=256", "-k=5"}, + {"-t=64", "-e=455", "-k=8"}, + {"-t=777", "-e=802", "-k=99"}, + {"-t=4097", "-e=906", "-k=51"}, + {"-t=128", "-e=32", "-k=5", "-local_t=6", "-moe_buf_interm_dim=262144"}, + {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"}, + {"-t=99", "-e=33", 
"-k=9", "-local_eid=6,10,11,15,19"}, + {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"}, + {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"}, + {"-t=128", "-e=128", "-k=6", "-moe_buf_interm_dim=163840", "-moe_buf_elem_bytes=1"}, + {"-t=8192", "-e=32", "-k=5", "-local_t=11", "-moe_buf_interm_dim=163840"}, + {"-t=8192", + "-e=32", + "-k=8", + "-local_t=12", + "-moe_buf_interm_dim=163840", + "-moe_buf_elem_bytes=1"}, + {"-t=8192", "-e=256", "-k=5", "-local_t=13", "-moe_buf_interm_dim=163840"}, + {"-t=8192", "-e=256", "-k=8", "-local_t=8", "-moe_buf_interm_dim=163840"}, + {"-t=163840", + "-e=256", + "-k=8", + "-local_t=4", + "-moe_buf_interm_dim=163840", + "-moe_buf_elem_bytes=4"}, + {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"}, + {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"}, + {"-t=99", "-local_t=93", "-e=121", "-local_t=4", "-moe_buf_interm_dim=10244"}, + {"-t=536", "-local_t=345", "-e=802", "-k=99"}, + {"-t=331", "-local_t=39", "-e=83", "-k=33"}, + {"-t=765", "-local_t=654", "-e=783", "-k=8"}, + {"-t=23", "-local_t=9", "-e=1", "-k=1"}, + {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"}, + {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"}, + {"-t=133940", + "-local_t=111921", + "-e=256", + "-k=17", + "-local_t=2", + "-moe_buf_interm_dim=133940", + "-moe_buf_elem_bytes=1"}}; + +#else + return {{"-t=80", "-e=17", "-moe_buf_size=16"}, + {"-t=111", "-e=117", "-moe_buf_size=4"}, + {"-t=1000", "-e=55", "-moe_buf_size=1024"}, + {"-t=99", "-e=120", "-moe_buf_size=10244"}, + {"-t=175", "-e=64", "-k=8"}, + {"-t=65", "-e=8", "-k=2"}, + {"-t=1", "-e=25"}, + {"-t=31", "-e=19", "-k=15"}, + {"-t=81", "-e=37", "-k=7"}, + {"-t=23", "-e=1", "-k=1"}, + {"-t=127", "-e=99", "-k=19"}, + {"-t=71", "-e=11", "-k=11"}, + {"-t=1", "-e=1", "-k=1"}, + {"-t=99", "-e=2", "-k=1"}, + {"-t=333", "-e=99", "-k=13"}, + {"-t=11", "-e=256", "-k=5"}, + {"-t=64", "-e=455", "-k=8"}, + {"-t=777", "-e=802", "-k=99"}, + 
{"-t=4097", "-e=906", "-k=51"}, + {"-t=128", "-e=32", "-k=5", "-moe_buf_size=262144"}, + {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"}, + {"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"}, + {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"}, + {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"}, + {"-t=128", "-e=128", "-k=6", "-moe_buf_size=163840"}, + {"-t=8192", "-e=32", "-k=5", "-moe_buf_size=163840"}, + {"-t=8192", "-e=32", "-k=8", "-moe_buf_size=163840"}, + {"-t=8192", "-e=256", "-k=5", "-moe_buf_size=163840"}, + {"-t=8192", "-e=256", "-k=8", "-moe_buf_size=163840"}, + {"-t=163840", "-e=256", "-k=8", "-moe_buf_size=163840"}, + {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"}, + {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"}, + {"-t=99", "-local_t=93", "-e=121", "-moe_buf_size=10244"}, + {"-t=536", "-local_t=345", "-e=802", "-k=99"}, + {"-t=331", "-local_t=39", "-e=83", "-k=33"}, + {"-t=765", "-local_t=654", "-e=783", "-k=8"}, + {"-t=23", "-local_t=9", "-e=1", "-k=1"}, + {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"}, + {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"}, + {"-t=133940", "-local_t=111921", "-e=256", "-k=17", "-moe_buf_size=133940"}}; +#endif +} + +int main() +{ + std::vector> test_cases = create_test_cases(); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/permute/CMakeLists.txt b/test/ck_tile/permute/CMakeLists.txt new file mode 100644 index 0000000000..7ee55a984d --- /dev/null +++ b/test/ck_tile/permute/CMakeLists.txt @@ -0,0 +1,33 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + + function(add_permute_test TARGET_NAME MAIN_SRC) + add_test_executable(${TARGET_NAME} ${MAIN_SRC}) + + if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) + set(PERMUTE_USE_ALTERNATIVE_IMPL true) + endif() + + if(PERMUTE_USE_ALTERNATIVE_IMPL) + target_compile_options(${TARGET_NAME} PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL) + 
target_sources(${TARGET_NAME} PRIVATE alternative_impl/matrix_core_swizzle.cpp) + endif() + + endfunction(add_permute_test TARGET_NAME MAIN_SRC) + + set(CUSTOM_TARGET_NAME test_ck_tile_permute) + + add_custom_target(${CUSTOM_TARGET_NAME}) + + add_permute_test(test_ck_tile_permute_fp16 permute_fp16.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp16) + + add_permute_test(test_ck_tile_permute_fp8 permute_fp8.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp8) + + add_permute_test(test_ck_tile_permute_fp32 permute_fp32.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp32) + +else() + message(DEBUG "Skipping ck_tile_permute tests for current target") +endif() diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.cpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.cpp new file mode 100644 index 0000000000..aedcfac138 --- /dev/null +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.cpp @@ -0,0 +1,101 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "matrix_core_swizzle.hpp" +#include "matrix_core_swizzle_kernel.hpp" + +float matrix_core_swizzle(matrix_core_swizzle_traits t, + matrix_core_swizzle_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.inst.compare("32x32x8") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + else if(t.inst.compare("16x16x16") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + 
constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + } + return -1; +} diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp new file mode 100644 index 0000000000..89dfeda4af --- /dev/null +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle.hpp @@ -0,0 +1,20 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include "matrix_core_swizzle_kernel.hpp" +#include + +struct matrix_core_swizzle_traits +{ + std::string data_type; // fp16 only + std::string inst; // 32x32x8, 16x16x16 + std::string permute; // +}; + +using matrix_core_swizzle_args = matrix_core_swizzle_host_args; + +// host API +float matrix_core_swizzle(matrix_core_swizzle_traits, + matrix_core_swizzle_args, + const ck_tile::stream_config&); diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp new file mode 100644 index 0000000000..518a9a8889 --- /dev/null +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -0,0 +1,413 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +// if set to 1, slightly more instructions generated to calculate address +#ifndef MERGE_2D_013425 +#define MERGE_2D_013425 0 +#endif + +enum class matrix_core_inst_enum +{ + MFMA_32x32x8_F16 = 0, + MFMA_16x16x16_F16 = 1, +}; + +namespace detail { +template +struct to_warp_gemm; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8; +}; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16; +}; +} // namespace detail +template +using to_warp_gemm_t = typename detail::to_warp_gemm::type; + +// TODO: in below permute pattern, the last 3 dim is within wave +enum class matrix_core_permute_style +{ + permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 + permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 + b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, +}; + +// assume this is B matrix, originally we have batch*n*k +// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 +// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4) +// +// 4(waves) 32(mfma_m lane) +// | | +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading) +// nr kr | +// nr 4 32 kr 2 8 2(klane) +// +// permute: 0,1,4,2,5,3,6 +// or +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading) +// permute: 0,1,2,4,5,3,6 +// +// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling +// for simplicity, only consider n/k is multiple of block-size + +// independend host arg with no template +struct matrix_core_swizzle_host_args +{ + const void* p_src; + void* p_dst; + int32_t batch; + int32_t n; + int32_t k; +}; + +// NOTE: this kernel could follow the style of generic permute kernel +// but here we pass in fixed layout as template arg and generate different 
kernel instance +// purposely +template +struct matrix_core_swizzle_kernel +{ + using karg = matrix_core_swizzle_host_args; + using harg = matrix_core_swizzle_host_args; + + static constexpr int BLOCK_SIZE = BLOCK_SIZE_; + static constexpr int WavesPerBlock_N = 4; + static constexpr int WavesPerBlock_K = 1; + static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE); + static constexpr int NPerBlock = NPerBlock_; + static constexpr int KPerBlock = KPerBlock_; + static constexpr matrix_core_permute_style pstyle = pstyle_; + static constexpr matrix_core_inst_enum Inst = Inst_; + + static constexpr ck_tile::index_t Alignment = 8; + karg a; + dim3 grids; + + using WarpGemm = to_warp_gemm_t; + + __host__ matrix_core_swizzle_kernel(harg h) + { + a = h; + ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock; + ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock; + grids = dim3(ks, ns, h.batch); + } + + __host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; } + + __host__ void operator()(const ck_tile::stream_config& s) const + { + ck_tile::kentry<<>>(a); + } + + struct kernel + { + __device__ static constexpr auto get_src_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<5, 3>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 4, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + 
__device__ static constexpr auto get_dst_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 2, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 3, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else + { + // clang-format off + // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten + constexpr index_t Kv = Alignment; + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + + static_assert(KPerBlock % (K1 * K2) == 0); + constexpr index_t Nr = NPerBlock / Nw; + constexpr index_t Kr = KPerBlock / (Kv * Kw); + + constexpr index_t Nr_p = WavesPerBlock_N; + constexpr index_t Kr_p = WavesPerBlock_K; + constexpr index_t Nr_y = Nr / Nr_p; + constexpr index_t Kr_y = Kr / Kr_p; + + return make_static_tile_distribution( +#if MERGE_2D_013425 + 
tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 + // minor 0 1 2 0 1 2 3 + tuple, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<2, 1>>, // p major + tuple, sequence<2, 2>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 2>, // Y major + sequence<0, 0, 3>>{}); // y minor +#else + tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 3 + // minor 0 1 0 1 0 1 2 + tuple, sequence, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<3, 3>>, // p major + tuple, sequence<0, 1>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 3>, // Y major + sequence<0, 0, 2>>{}); // y minor +#endif + // clang-format on + } + } + + __device__ void operator()(karg a_) + { + using namespace ck_tile; + index_t i_k = blockIdx.x; + index_t i_n = blockIdx.y; + index_t i_b = blockIdx.z; + + constexpr index_t k2 = Alignment; + constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1 = BLOCK_SIZE / get_warp_size(); + const index_t k0 = a_.k / (k1 * k2); + const index_t n0 = a_.n / (n1 * n2); + + constexpr index_t k2_tile = Alignment; + constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size(); + constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile); + constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile); + + const fp16_t* p_src = reinterpret_cast(a_.p_src) + i_b * a_.k * a_.n; + fp16_t* p_dst = reinterpret_cast(a_.p_dst) + i_b * a_.k * a_.n; + + const auto src_view = [&]() { + const auto tmp = make_naive_tensor_view_packed( + p_src, + make_tuple(n0, n1, n2, k0, k1, k2), + number{}); // control vector load + return tmp; + }(); + + const auto src_window = make_tile_window(src_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, 0, i_k * 
k0_tile, 0, 0}, + get_src_dist()); + + auto dst_view = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, k0, n1, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, n1, k0, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else + { +#if MERGE_2D_013425 + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + // constexpr index_t waveflatten = kw*nw*kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, number{}, number{}, number{}), + number{}); // control vector load + auto tmp_1 = transform_tensor_view( + tmp, + make_tuple( + make_merge_transform(make_tuple(nr, number{})), + make_merge_transform(make_tuple(kr, number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return tmp_1; +#else + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten = kw * nw * kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, waveflatten), + number{}); // control vector load + return tmp; +#endif + } + }(); + + auto dst_window = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + 
number{}, + number{}), + {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0}, + get_dst_dist()); + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0}, + get_dst_dist()); + } + else + { +#if MERGE_2D_013425 + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv + return make_tile_window(dst_view, + make_tuple(number{}, number{}), + {i_n * NPerBlock, i_k * KPerBlock}, + get_dst_dist()); +#else + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten_tile = kw * nw * kv; + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}), + {i_n * nr_tile, i_k * kr_tile, 0}, + get_dst_dist()); +#endif + } + }(); + + // actual load store + auto src_tile = load_tile(src_window); + + // now we only swap the distribution from src to dst, no extra movement occurs + auto dst_tile = make_static_distributed_tensor(get_dst_dist()); + dst_tile.get_thread_buffer() = src_tile.get_thread_buffer(); + + // final store + store_tile(dst_window, dst_tile); + } + }; +}; diff --git a/test/ck_tile/permute/permute.hpp b/test/ck_tile/permute/permute.hpp new file mode 100644 index 0000000000..5724b0f316 --- /dev/null +++ b/test/ck_tile/permute/permute.hpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/permute.hpp" +#include + +struct permute_traits +{ + std::string data_type; +}; + +using permute_args = ck_tile::GenericPermuteHostArgs; + +// host API +float permute(permute_traits, permute_args, const ck_tile::stream_config&); diff --git a/test/ck_tile/permute/permute_fp16.cpp b/test/ck_tile/permute/permute_fp16.cpp new file mode 100644 index 0000000000..24781261ef --- /dev/null +++ b/test/ck_tile/permute/permute_fp16.cpp @@ -0,0 +1,29 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "permute.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL +#include "alternative_impl/matrix_core_swizzle.hpp" +#endif + +#include "permute_utils.inc" + +int main() +{ + std::vector> test_cases = create_test_cases_fp16(); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/permute/permute_fp32.cpp b/test/ck_tile/permute/permute_fp32.cpp new file mode 100644 index 0000000000..2ece7c20bb --- /dev/null +++ b/test/ck_tile/permute/permute_fp32.cpp @@ -0,0 +1,29 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "permute.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL +#include "alternative_impl/matrix_core_swizzle.hpp" +#endif + +#include "permute_utils.inc" + +int main() +{ + std::vector> test_cases = create_test_cases("fp32"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/permute/permute_fp8.cpp b/test/ck_tile/permute/permute_fp8.cpp new file mode 100644 index 0000000000..e8ae5d0410 --- /dev/null +++ b/test/ck_tile/permute/permute_fp8.cpp @@ -0,0 +1,29 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "permute.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL +#include "alternative_impl/matrix_core_swizzle.hpp" +#endif + +#include "permute_utils.inc" + +int main() +{ + std::vector> test_cases = create_test_cases("fp8"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/permute/permute_utils.inc b/test/ck_tile/permute/permute_utils.inc new file mode 100644 index 0000000000..6b8cb86b53 --- /dev/null +++ b/test/ck_tile/permute/permute_utils.inc @@ -0,0 +1,490 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +namespace detail { +template +struct to_integer_type; + +template <> +struct to_integer_type<4> +{ + using type = int32_t; +}; +template <> +struct to_integer_type<2> +{ + using type = int16_t; +}; +template <> +struct to_integer_type<1> +{ + using type = int8_t; +}; +} // namespace detail + +template +using to_integer_type = typename detail::to_integer_type::type; + +// host API (shoule come from codegen) +float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp8") == 0) + { + using DataType = ck_tile::fp8_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp16") == 0) + { + using DataType = ck_tile::half_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp32") == 0) + { + using DataType = float; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename 
std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[], int start_index = 0) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)") + .insert("shape", "2,3,4", "the shape of the input tensor") + .insert("perm", "2,1,0", "permute perm") + .insert("kname", "0", "t to 1 will print kernel name") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv, start_index); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +// "1,2,3,4" -> vector{1,2,3,4} +std::vector decode_vec(std::string q_val) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + std::string::size_type pos = 0; + std::vector v; + while(true) + { + auto found = q_val.find(',', pos); + ck_tile::index_t n = + 
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos)); + v.push_back(n); + if(found == std::string::npos) + { + break; + } + pos = found + 1; + } + return v; +#undef _S2I_ +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + auto shape = decode_vec(arg_parser.get_str("shape")); + auto perm = decode_vec(arg_parser.get_str("perm")); + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + int seed = arg_parser.get_int("seed"); + + assert(shape.size() == perm.size()); + ck_tile::index_t rank = perm.size(); + if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks) + { + printf("rank %d permute is not support yet\n", rank); + return false; + } + + ck_tile::HostTensor x(shape); + ck_tile::FillUniformDistributionIntegerValue{-15, 15, seed}(x); + + std::vector y_shape = [&]() { + std::vector tmp(rank, 0); + // std::cout << "@@@@" << tmp << std::endl; + for(int i = 0; i < static_cast(rank); i++) + { + // std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" << + // static_cast(rank) + // << std::endl; + tmp[i] = shape[perm[i]]; + } + // std::cout << "@@@" << tmp << std::endl; + return tmp; + }(); + + ck_tile::HostTensor y(y_shape); + + ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x.data()); + + std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm + << std::endl; + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + stream_warmup, + stream_repeat}; + float ave_time = 0.f; + auto run_permute = [&]() { + permute_traits t; + t.data_type = data_type; + + permute_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.rank = rank; + std::copy(shape.begin(), shape.end(), a.shape); + std::copy(perm.begin(), perm.end(), a.perm); + + return permute(t, a, stream_config); + }; +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL + // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 + if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))) + { + if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) + { + // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + + auto nr = shape[1]; + auto nw = shape[2]; + auto kr = shape[3]; + auto kw = shape[4]; + auto kv = shape[5]; + a.n = nr * nw; + a.k = kr * kw * kv; + if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + else + { + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + a.n = shape[1] * shape[2] * shape[3]; + a.k = shape[4] * 
shape[5] * shape[6]; + if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && + shape[4] % 8 == 0 && shape[1] % 2 == 0) + { + // 32x32x8 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8) + // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8) + + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 && + shape[4] % 4 == 0 && shape[1] % 4 == 0) + { + // 16x16x16 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,4x,4x,4,4,16,8 + // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8) + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + } + else +#endif + { + ave_time = run_permute(); + } + std::cout << ", time:" << ave_time << "ms" << std::flush; + + bool pass = true; + if(do_validation) + { + reference_permute(x, y, perm); + + ck_tile::HostTensor y_dev(y.get_lengths()); + + y_buf.FromDevice(y_dev.data()); + + pass = std::equal( + y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) { + using itype = to_integer_type; + itype i_d = ck_tile::bit_cast(d); + itype i_h = ck_tile::bit_cast(h); + return i_d == i_h; + }); + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush; + } + + std::cout << std::endl; + + return pass; +} + +template +bool run_test_case(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + + if(!result) + return false; + + return run(arg_parser); +} + +template +bool run_test_cases(std::vector>& test_cases) +{ + bool valid = true; + constexpr int num_args = 6; + char* argv[num_args]; + + for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx) + { + assert(test_cases[test_idx].size() == num_args && + "invalid number of arguments in test case"); + + for(int arg_idx = 0; arg_idx < num_args; ++arg_idx) + { + argv[arg_idx] = test_cases[test_idx][arg_idx].data(); + } + + valid = valid && run_test_case(num_args, argv); + + if(!valid) + break; + } + + return valid; +} + +std::vector> create_test_cases(const std::string prec) +{ + return { + {"-prec=" + prec, "-shape=3,8", "-perm=1,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, "-shape=48,6,8", "-perm=2,1,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, "-shape=24,128,3", "-perm=0,2,1", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, "-shape=4,10,7,6", "-perm=0,2,3,1", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, "-shape=8,24,36,10", "-perm=3,1,2,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, "-shape=8,1,36,4", "-perm=2,1,0,3", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=" + prec, + "-shape=5,10,16,2,36,4", + "-perm=4,5,2,1,0,3", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=" + prec, + "-shape=2,32,8,3,6,2,5,4", + "-perm=5,2,4,7,1,6,3,0", + "-v=1", + "-warmup=0", + "-repeat=1"}}; +} + +std::vector> create_test_cases_fp16() +{ + return {{"-prec=fp16", + "-shape=3,6,4,32,16,2,8", + "-perm=0,1,4,2,5,3,6", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=5,10,4,32,8,2,8", + "-perm=0,1,4,2,5,3,6", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=3,8,4,16,16,4,8", + "-perm=0,1,4,2,5,3,6", + "-v=1", + "-warmup=0", + 
"-repeat=1"}, + {"-prec=fp16", + "-shape=3,6,4,32,16,2,8", + "-perm=0,1,2,4,5,3,6", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=5,10,4,32,8,2,8", + "-perm=0,1,2,4,5,3,6", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=3,8,4,16,16,4,8", + "-perm=0,1,2,4,5,3,6", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=2,8,16,8,4,8", + "-perm=0,1,3,4,2,5", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=1,24,32,16,2,8", + "-perm=0,1,3,4,2,5", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", "-shape=3,8", "-perm=1,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", "-shape=48,6,8", "-perm=2,1,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", "-shape=24,128,3", "-perm=0,2,1", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", "-shape=4,10,7,6", "-perm=0,2,3,1", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", "-shape=8,24,36,10", "-perm=3,1,2,0", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", "-shape=8,1,36,4", "-perm=2,1,0,3", "-v=1", "-warmup=0", "-repeat=1"}, + {"-prec=fp16", + "-shape=5,10,16,2,36,4", + "-perm=4,5,2,1,0,3", + "-v=1", + "-warmup=0", + "-repeat=1"}, + {"-prec=fp16", + "-shape=2,32,8,3,6,2,5,4", + "-perm=5,2,4,7,1,6,3,0", + "-v=1", + "-warmup=0", + "-repeat=1"}}; +}