Tests for CK tile Permute and MOE Sorting (#2417)

* Convert ck-tile 06_permute smoke test to unit tests for fp16, fp8, and fp32

* Apply clang format and update copyright year

* Convert ck tile moe sorting example smoke tests to unit tests

* fix CMakeLists to ensure that permute and moe_sorting are built for gfx9 only.

* Remove number prefix from permute and moe_sorting directory names

* code cleanup

* add missing test cases for fp16 permute

* remove unnecessary parentheses

* Cleanup

* Remove unnecessary final nullptr

* update copyright and licensing statement in files

* Add custom target for permute tests

* Add missing new line at end of file for moe sorting CMakeLists.

* Update MOE sorting tests to account for MOE sorting example updates

The ck_tile/13_moe_sorting example was updated to include different
cases depending on whether MOE_SORTING_FMOE_2D_BUF is set. So,
the ck_tile tests for MOE sorting were updated to account for these
changes.

---------

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

[ROCm/composable_kernel commit: 1fa1c34b7e]
This commit is contained in:
Emily Martins
2025-07-21 12:20:28 -06:00
committed by GitHub
parent 682c158b28
commit d5fcf10b29
14 changed files with 2195 additions and 0 deletions

View File

@@ -8,6 +8,8 @@ add_subdirectory(data_type)
# Not including these tests as there is a bug on gfx90a and gfx942
# resulting in "GPU core dump"
#add_subdirectory(moe_smoothquant)
add_subdirectory(permute)
add_subdirectory(moe_sorting)
add_subdirectory(slice_tile)
add_subdirectory(batched_transpose)
add_subdirectory(smoothquant)

View File

@@ -0,0 +1,15 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
    add_test_executable(test_ck_tile_moe_sorting_fp32 moe_sorting_fp32.cpp moe_sorting_api.cpp)
    target_include_directories(test_ck_tile_moe_sorting_fp32 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
    # NOTE: -Wno-undefined-func-template lets the source compile without explicitly
    # declaring function specializations; -Wno-float-equal silences exact float
    # comparisons in the checks. (Renamed from EXAMPLE_* — these are test, not
    # example, options.)
    set(MOE_SORTING_TEST_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
    # Uncomment for compiler debugging:
    # list(APPEND MOE_SORTING_TEST_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
    target_compile_options(test_ck_tile_moe_sorting_fp32 PRIVATE ${MOE_SORTING_TEST_COMPILE_OPTIONS})
else()
    message(DEBUG "Skipping ck_tile_moe_sorting tests for current target")
endif()

View File

@@ -0,0 +1,444 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "moe_sorting_api.hpp"
// Build-time configuration knobs; each may be overridden from the build system.
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0
#endif
#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK
#define MOE_SORTING_SUPPORT_LARGE_TOPK 0
#endif
#if !MOE_SORTING_USE_EX_KERNEL
// Legacy kernel path: instantiate MoeSortingKernel for a compile-time unroll
// factor and expert tile size, launch it, and return the measured time.
// Expands to statements ending in `return`, so it must be the tail of the
// enclosing function branch.
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                            \
    constexpr ck_tile::index_t unroll_num  = unroll_num_;                                \
    constexpr ck_tile::index_t expert_tile = expert_tile_;                               \
    using ms_problem =                                                                   \
        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;    \
    using kernel           = ck_tile::MoeSortingKernel<ms_problem>;                      \
    auto kargs             = kernel::MakeKargs(a);                                       \
    const dim3 grids       = kernel::GridSize(a);                                        \
    const dim3 blocks      = kernel::BlockSize(a);                                       \
    const auto lds_bytes   = kernel::GetSmemSize(a);                                     \
    float ave_time         = ck_tile::launch_kernel(                                     \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));             \
    return ave_time;
#else
// "Ex" kernel path: instantiate MoeSortingProblemEx with all four options
// resolved to compile-time values, launch, and return the measured time.
#define MOE_SORTING_DISPATCH_(                                                           \
    sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_)             \
    constexpr ck_tile::index_t sub_token_tile = sub_token_tile_;                         \
    constexpr bool sub_token_onshot           = sub_token_onshot_;                       \
    constexpr bool local_expert_masking       = local_expert_masking_;                   \
    constexpr bool local_token                = local_token_;                            \
    using ms_problem = ck_tile::MoeSortingProblemEx<index_t,                             \
                                                    ms_weight_type,                      \
                                                    sub_token_tile,                      \
                                                    sub_token_onshot,                    \
                                                    local_expert_masking,                \
                                                    local_token>;                        \
    using kernel         = ck_tile::MoeSortingKernel<ms_problem>;                        \
    auto kargs           = kernel::MakeKargs(a);                                         \
    const dim3 grids     = kernel::GridSize(a);                                          \
    const dim3 blocks    = kernel::BlockSize(a);                                         \
    const auto lds_bytes = kernel::GetSmemSize(a);                                       \
    float ave_time       = ck_tile::launch_kernel(                                       \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));             \
    return ave_time;
// Picks the largest sub-token tile (8/4/2/1) that evenly divides row_.
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(                                                 \
    row_, sub_token_onshot_, local_expert_masking_, local_token_)                        \
    if(row_ % 8 == 0)                                                                    \
    {                                                                                    \
        MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
    }                                                                                    \
    else if(row_ % 4 == 0)                                                               \
    {                                                                                    \
        MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
    }                                                                                    \
    else if(row_ % 2 == 0)                                                               \
    {                                                                                    \
        MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
        MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
    }
// Resolves the runtime flag `is_local_token` into a compile-time bool.
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
    if(is_local_token)                                                                   \
    {                                                                                    \
        MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
        MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
    }
// Resolves the runtime flag `is_sub_token_onshot` into a compile-time bool.
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_)                         \
    if(is_sub_token_onshot)                                                              \
    {                                                                                    \
        MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_)           \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
        MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_)          \
    }
// Top of the dispatch chain: resolves `is_local_expert_masking` first.
#define MOE_SORTING_DISPATCH_EMASK_(row_)                                                \
    if(is_local_expert_masking)                                                          \
    {                                                                                    \
        MOE_SORTING_DISPATCH_SUBTO_(row_, true)                                          \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
        MOE_SORTING_DISPATCH_SUBTO_(row_, false)                                         \
    }
#endif
#if !MOE_SORTING_USE_EX_KERNEL
// Legacy path: choose the smallest expert tile (8..64, or 0 for the generic
// fallback) covering a.num_experts. Every branch ends in `return`.
#define MOE_SORTING_DISPATCH(unroll_num_)                                                \
    if(a.num_experts <= 8)                                                               \
    {                                                                                    \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)                                       \
    }                                                                                    \
    else if(a.num_experts <= 16)                                                         \
    {                                                                                    \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16)                                      \
    }                                                                                    \
    else if(a.num_experts <= 32)                                                         \
    {                                                                                    \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32)                                      \
    }                                                                                    \
    else if(a.num_experts <= 64)                                                         \
    {                                                                                    \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64)                                      \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)                                       \
    }
#endif
// Single-kernel MOE sorting entry point. Only fp32 weights with int32 indices
// are supported; any other type combination returns -1 ("not supported").
// On success returns the time measured by ck_tile::launch_kernel.
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
#if !MOE_SORTING_USE_EX_KERNEL
        // Legacy kernel: hard limits on expert count and on buffer alignment.
        if(a.num_experts > 127)
        {
            printf("lds size exceed, only support experts <127 \n");
            return -1;
        }
        if(a.moe_buf_bytes % 16)
        {
            printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
            return -1;
        }
        // `index_t` / `ms_weight_type` are referenced inside the dispatch macros.
        using index_t = ck_tile::index_t;
        using ms_weight_type = float;
        // Select the smem unroll factor from the total topk-id count. Each
        // MOE_SORTING_DISPATCH branch ends in `return`, so the missing `break`
        // statements cannot cause a fallthrough in practice.
        index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
        switch(smem_io_unroll_num)
        {
        case(1): {
            MOE_SORTING_DISPATCH(1);
        }
        case(2): {
            MOE_SORTING_DISPATCH(2);
        }
        case(3): {
            MOE_SORTING_DISPATCH(3);
        }
        case(5): {
            MOE_SORTING_DISPATCH(5);
        }
        case(6): {
            MOE_SORTING_DISPATCH(6);
        }
        case(8): {
            MOE_SORTING_DISPATCH(8);
        }
        case(10): {
            MOE_SORTING_DISPATCH(10);
        }
        default: {
            MOE_SORTING_DISPATCH(4);
        }
        }
#else
        // If this shape needs a workspace, delegate to the multi-phase path.
        if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
        {
            return moe_sorting_mp(t, a, s);
        }
        // `index_t` / `ms_weight_type` are referenced inside the dispatch macros.
        using index_t = ck_tile::index_t;
        using ms_weight_type = float;
        // Runtime knobs below are turned into compile-time template parameters
        // by the MOE_SORTING_DISPATCH_* macro chain, which ends in `return`.
        auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts);
        auto row_ = sub_token_ / 8;
        bool is_sub_token_onshot = a.tokens <= sub_token_;
        bool is_local_expert_masking = t.local_expert_masking;
        bool is_local_token = a.p_local_tokens != nullptr;
        MOE_SORTING_DISPATCH_EMASK_(row_);
        // MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
    }
    return -1;
}
// Multi-phase kernel builders. Each MOE_SORTING_MP_N macro expands to an
// immediately-invoked lambda returning a ready-to-launch kernel callable for
// phase N, with the mesh type, unroll factor, expert masking and local-token
// options baked in at compile time.
// NOTE(review): the "MOR_SORTING_*" names below look like typos for
// "MOE_SORTING_*" — kept as-is since renaming would touch all call sites.
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_)          \
    [&]() {                                                                               \
        constexpr ck_tile::index_t unroll_num = unroll_num_;                              \
        constexpr bool expert_masking         = expert_masking_;                          \
        constexpr bool local_token            = local_token_;                             \
        using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t,                       \
                                                        ms_weight_type,                   \
                                                        mesh_type_,                       \
                                                        unroll_num,                       \
                                                        expert_masking,                   \
                                                        local_token>;                     \
        using kernel      = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>;           \
        auto kargs        = kernel::MakeKargs(a);                                         \
        const dim3 grids  = kernel::GridSize(a);                                          \
        const dim3 blocks = kernel::BlockSize(a);                                         \
        return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
    }()
// Phase 1 builder; structurally identical to MOE_SORTING_MP_0 except for the
// kernel type.
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_)          \
    [&]() {                                                                               \
        constexpr ck_tile::index_t unroll_num = unroll_num_;                              \
        constexpr bool expert_masking         = expert_masking_;                          \
        constexpr bool local_token            = local_token_;                             \
        using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t,                       \
                                                        ms_weight_type,                   \
                                                        mesh_type_,                      \
                                                        unroll_num,                      \
                                                        expert_masking,                  \
                                                        local_token>;                    \
        using kernel      = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>;           \
        auto kargs        = kernel::MakeKargs(a);                                         \
        const dim3 grids  = kernel::GridSize(a);                                          \
        const dim3 blocks = kernel::BlockSize(a);                                         \
        return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
    }()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
// Split phase-2 / phase-3 builders, used only when P23 does not fit in LDS.
// NOTE(review): unlike MP_0/MP_1/MP_23, these call make_kernel without the
// <kernel::BLOCK_SIZE> template argument — verify this is intentional.
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_)          \
    [&]() {                                                                               \
        constexpr ck_tile::index_t unroll_num = unroll_num_;                              \
        constexpr bool expert_masking         = expert_masking_;                          \
        constexpr bool local_token            = local_token_;                             \
        using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t,                       \
                                                        ms_weight_type,                  \
                                                        mesh_type_,                      \
                                                        unroll_num,                      \
                                                        expert_masking,                  \
                                                        local_token>;                    \
        using kernel      = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>;           \
        auto kargs        = kernel::MakeKargs(a);                                         \
        const dim3 grids  = kernel::GridSize(a);                                          \
        const dim3 blocks = kernel::BlockSize(a);                                         \
        return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs);                   \
    }()
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_)          \
    [&]() {                                                                               \
        constexpr ck_tile::index_t unroll_num = unroll_num_;                              \
        constexpr bool expert_masking         = expert_masking_;                          \
        constexpr bool local_token            = local_token_;                             \
        using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t,                       \
                                                        ms_weight_type,                  \
                                                        mesh_type_,                      \
                                                        unroll_num,                      \
                                                        expert_masking,                  \
                                                        local_token>;                    \
        using kernel      = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>;           \
        auto kargs        = kernel::MakeKargs(a);                                         \
        const dim3 grids  = kernel::GridSize(a);                                          \
        const dim3 blocks = kernel::BlockSize(a);                                         \
        return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs);                   \
    }()
#endif
// Fused phase-2+3 builder; this variant also needs LDS sized by GetSmemSize.
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_)         \
    [&]() {                                                                               \
        constexpr ck_tile::index_t unroll_num = unroll_num_;                              \
        constexpr bool expert_masking         = expert_masking_;                          \
        constexpr bool local_token            = local_token_;                             \
        using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t,                       \
                                                        ms_weight_type,                  \
                                                        mesh_type_,                      \
                                                        unroll_num,                      \
                                                        expert_masking,                  \
                                                        local_token>;                    \
        using kernel        = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>;        \
        auto kargs          = kernel::MakeKargs(a);                                       \
        const dim3 grids    = kernel::GridSize(a);                                        \
        const dim3 blocks   = kernel::BlockSize(a);                                       \
        const auto lds_size = kernel::GetSmemSize(a);                                     \
        return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
    }()
// Launches the P0 -> P1 -> P23 pipeline (plus optional workspace clear) with
// the two runtime flags resolved to compile-time bools; every branch returns.
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_)   \
    if(t.local_expert_masking)                                                            \
    {                                                                                     \
        if(is_local_token)                                                                \
        {                                                                                 \
            float ave_time =                                                              \
                ck_tile::launch_kernel(s,                                                 \
                                       maybe_clear_workspace,                             \
                                       MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
                                       MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
                                       MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
            return ave_time;                                                              \
        }                                                                                 \
        else                                                                              \
        {                                                                                 \
            float ave_time =                                                              \
                ck_tile::launch_kernel(s,                                                 \
                                       maybe_clear_workspace,                             \
                                       MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
                                       MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
                                       MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
            return ave_time;                                                              \
        }                                                                                 \
    }                                                                                     \
    else                                                                                  \
    {                                                                                     \
        if(is_local_token)                                                                \
        {                                                                                 \
            float ave_time =                                                              \
                ck_tile::launch_kernel(s,                                                 \
                                       maybe_clear_workspace,                             \
                                       MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
                                       MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
                                       MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
            return ave_time;                                                              \
        }                                                                                 \
        else                                                                              \
        {                                                                                 \
            float ave_time = ck_tile::launch_kernel(                                      \
                s,                                                                        \
                maybe_clear_workspace,                                                    \
                MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false),                 \
                MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false),                 \
                MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false));              \
            return ave_time;                                                              \
        }                                                                                 \
    }
// Builds the workspace-clearing kernel callable (used by maybe_clear_workspace).
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_)               \
    [&]() {                                                                               \
        using problem_ =                                                                  \
            ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
        using kernel      = ck_tile::MoeSortingClearWorkspaceKernel<problem_>;            \
        auto kargs        = kernel::MakeKargs(a);                                         \
        const dim3 grids  = kernel::GridSize(a);                                          \
        const dim3 blocks = kernel::BlockSize(a);                                         \
        return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
    }()
// Multi-phase MOE sorting entry point (used when a workspace is required).
// Dispatches on the mesh element byte size and on runtime flags; each
// MOR_SORTING_MP_DISPATCH_ branch ends in `return`. Returns -1 for unsupported
// type combinations or unsupported large-expert / large-topk shapes.
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
    bool is_local_token = a.p_local_tokens != nullptr;
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
        // Referenced inside the MP builder macros.
        using ms_index_t = ck_tile::index_t;
        using ms_weight_type = float;
        // Optional prologue that zeroes the workspace on the GPU when the
        // caller asked the API to do the clearing (trait.clear_workspace_inside_api).
        auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
            if(t.clear_workspace_inside_api)
            {
                if(is_local_token)
                {
                    auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
                    k(s_);
                }
                else
                {
                    auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
                    k(s_);
                }
            }
        };
        // The fused P23 kernel needs LDS proportional to num_experts; fall
        // back to split P2/P3 kernels (large-expert build only) when it
        // exceeds the device LDS capacity.
        if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
           ck_tile::get_smem_capacity())
        {
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
            // NOTE(review): MOE_SORTING_MP_0/1/2/3 take 4 arguments but are
            // invoked with 3 here, so this branch looks like it would fail to
            // compile when MOE_SORTING_SUPPORT_LARGE_EXPERT=1 — verify.
            if(t.local_expert_masking)
            {
                float ave_time = ck_tile::launch_kernel(s,
                                                        maybe_clear_workspace,
                                                        MOE_SORTING_MP_0(ms_index_t, 1, true),
                                                        MOE_SORTING_MP_1(ms_index_t, 1, true),
                                                        MOE_SORTING_MP_2(ms_index_t, 1, true),
                                                        MOE_SORTING_MP_3(ms_index_t, 1, true));
                return ave_time;
            }
            else
            {
                float ave_time = ck_tile::launch_kernel(s,
                                                        maybe_clear_workspace,
                                                        MOE_SORTING_MP_0(ms_index_t, 1, false),
                                                        MOE_SORTING_MP_1(ms_index_t, 1, false),
                                                        MOE_SORTING_MP_2(ms_index_t, 1, false),
                                                        MOE_SORTING_MP_3(ms_index_t, 1, false));
                return ave_time;
            }
#else
            printf("do not support large expert %d\n", a.num_experts);
            return -1;
#endif
        }
        else
        {
            // Mesh element width is chosen from the problem size.
            ck_tile::index_t mesh_byte_size =
                ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
            if(mesh_byte_size == 1)
            {
                // Use 4-wide vectorized loads when tokens*topk allows it.
                if(a.tokens * a.topk % 4 == 0)
                {
                    MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
                }
                else
                {
                    MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
                }
            }
            else if(mesh_byte_size == 2)
            {
#if MOE_SORTING_SUPPORT_LARGE_TOPK
                if(a.tokens * a.topk % 4 == 0)
                {
                    MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
                }
                else
                {
                    MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
                }
#else
                printf("do not support large topk %d\n", a.topk);
                return -1;
#endif
            }
            else
            {
                MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
            }
        }
    }
    return -1;
}
// Thin forwarding wrapper around ck_tile's workspace-size query. Returns 0
// when no workspace is needed (see the usage note in moe_sorting_api.hpp).
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
{
    return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
}

View File

@@ -0,0 +1,33 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/fused_moe.hpp"
// Describes which moe_sorting variant to run. Type names are passed as
// strings; only "fp32" weights with "int32" indices are accepted by the
// implementations in moe_sorting_api.cpp.
struct moe_sorting_trait
{
    std::string index_type;
    std::string weight_type; // currently always float
    bool local_expert_masking; // if mask experts as local expert
    bool clear_workspace_inside_api; // if true, no need to clear the workspace outside (it will be
                                     // taken care of inside the API)
    int dispatch_policy; // 0 - let the API choose the kernel for you. 1 - always use single
                         // kernel. 2 - always use mp kernel. NOTE: moe_sorting_get_workspace_size()
                         // must be called with the same dispatch_policy value; behavior is
                         // undefined if a different value is used when querying the workspace size
                         // vs. calling the kernel
};
// Kernel arguments; inherits every field from ck_tile::MoeSortingHostArgs.
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{
};
// Use the API below before calling moe_sorting() to find out whether a workspace is
// needed. If it returns non-zero, allocate a GPU buffer of that size and set it as
// moe_sorting_args.p_ws.
// NOTE: the workspace is required to be cleared to zero before using the API.
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);

View File

@@ -0,0 +1,538 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
// Builds the argument parser for a single moe-sorting test case and parses
// `argv` into it starting at `index`. Returns {parse_succeeded, parser}.
// The accepted buffer options differ depending on MOE_SORTING_FMOE_2D_BUF.
auto create_args(int argc, char* argv[], int index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
        .insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
        .insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
        .insert("t",
                "128",
                "number of input tokens.\n"
                "If \"local_t\" presents, this value indicates global concurrency of all ranks.")
        .insert(
            "local_t",
            "-1",
            "Number of local input tokens for curent rank.\n"
            "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
            "This feature is to simulate EP case where where each rank has different tokens.\n"
            "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
        .insert("e", "8", "number of num_experts")
        .insert("k", "4", "topk")
        .insert("unit", "32", "unit_size")
#if MOE_SORTING_FMOE_2D_BUF
        // 2D-buffer build: the fmoe buffer is described by (interm_dim, elem_bytes).
        .insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
        .insert(
            "moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
#else
        // Flat-buffer build: a single element count.
        .insert("moe_buf_size", "0", "moe_buf_size")
#endif
        .insert("ci",
                "1",
                "clear workspace inside API or not(if \"0\", require manually clear outside)")
        .insert(
            "dispatch",
            "0",
            "dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
        .insert("local_eid",
                "-1",
                "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
                "please make sure eid is in ascending order!")
        .insert("seed",
                "-1",
                "seed to be used. When set to -1, a random seed will be generated each time "
                "invoking this example")
        .insert("kname", "0", "prints the kernel name when set to 1")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel");
    bool result = arg_parser.parse(argc, argv, index);
    return std::make_tuple(result, arg_parser);
}
// Fills `host_tensor` (size tokens*topk) with random expert ids such that the
// `topk` ids generated for each token are unique within that token.
//
// host_tensor: output buffer; must hold at least tokens*topk elements.
// tokens/topk: number of tokens and ids-per-token to generate.
// num_expert : ids are drawn uniformly from [0, num_expert). The caller must
//              ensure topk <= num_expert, otherwise generation cannot finish.
// seed       : passed to std::srand, making the output reproducible.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    // Widen before multiplying to avoid int overflow for very large cases.
    size_t total_size = static_cast<size_t>(topk) * tokens;
    std::srand(seed);
    // Ids already used for the current token. unordered_set gives O(1)
    // membership tests; <unordered_set> was already included but unused.
    std::unordered_set<IndexType> unique_set;
    IndexType current_v;
    for(size_t i = 0; i < total_size; i++)
    {
        if(i % topk == 0)
        {
            // Starting a new token: its topk ids must be mutually unique.
            unique_set.clear();
        }
        // Re-draw until we get an id not yet used for this token.
        current_v = std::rand() % num_expert;
        while(unique_set.find(current_v) != unique_set.end())
        {
            current_v = std::rand() % num_expert;
        }
        unique_set.insert(current_v);
        host_tensor[i] = current_v;
    }
}
// Runs one moe-sorting test case described by `args`:
//   1. builds host inputs (topk ids, weights, optional local-expert mask),
//   2. uploads them, queries/allocates the workspace, and calls moe_sorting(),
//   3. optionally validates every output against the host reference.
// Returns true when the case ran and (if validation is on) matched the
// reference; false for invalid configurations, unsupported cases or mismatches.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
    int validate = args.get_int("v");
    std::string index_prec = args.get_str("pr_i");
    std::string weight_prec = args.get_str("pr_w");
    int tokens = args.get_int("t");
    int local_tokens = args.get_int("local_t");
    int num_experts = args.get_int("e");
    int topk = args.get_int("k");
    int seed = args.get_int("seed");
    int unit_size = args.get_int("unit");
#if MOE_SORTING_FMOE_2D_BUF
    int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
    int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
#else
    int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
#endif
    int kname = args.get_int("kname");
    int warmup = args.get_int("warmup");
    int repeat = args.get_int("repeat");
    bool clear_inside = args.get_int("ci") != 0;
    int dispatch_policy = args.get_int("dispatch");
    // Worst-case number of sorted output ids, rounded up to unit_size.
    int max_output_ids =
        ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
    if(seed < 0)
    {
        // -1 (default) means: pick a fresh seed per invocation.
        seed = std::time(nullptr);
    }
    if(topk > num_experts)
    {
        printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
               topk,
               num_experts);
        return false;
    }
    // if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
    // case
    bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
    if(local_tokens > tokens)
    {
        printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
        return false;
    }
    // Build the per-expert 0/1 mask from the "local_eid" list (1 = local expert).
    bool local_expert_masking = args.get_str("local_eid") != "-1";
    auto local_expert_masking_host = [&]() {
        if(local_expert_masking)
        {
            auto local_eid = args.get_int_vec("local_eid");
            ck_tile::HostTensor<IndexType> v_{{num_experts}};
            v_.SetZero();
            for(auto eid : local_eid)
            {
                if(eid >= num_experts)
                {
                    throw std::runtime_error(
                        "local_eid larger than number of expert, please check");
                }
                v_.mData[eid] = 1;
            }
            return v_;
        }
        else
            // Placeholder tensor when masking is disabled.
            return ck_tile::HostTensor<IndexType>{{1}};
    }();
    // tokens already considered batch size
    ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
    ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
    ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
    // for simplicity, below buffer allocate 2 dword
    ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
    // 2D fmoe buffer sized as rows(tokens) x interm_dim x elem_bytes.
    ck_tile::HostTensor<int8_t> moe_buf_host(
        {static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
         moe_buf_elem_bytes});
    auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
                                                 : moe_buf_host.get_element_space_size_in_bytes();
#else
    ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
    auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
                                           : moe_buf_host.get_element_space_size_in_bytes();
#endif
    // Random weights; the buffer is filled with non-zero data so that the
    // kernel's zeroing of it can be verified below.
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
    ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
    topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
    // Device buffers mirroring every host tensor.
    ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_expert_ids_dev(
        sorted_expert_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem local_expert_masking_dev(
        local_expert_masking_host.get_element_space_size_in_bytes());
    // used for simulating dynamic_tokens for EP case
    ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
    if(is_local_token)
    {
        local_tokens_dev.ToDevice(&local_tokens);
    }
    topk_ids_dev.ToDevice(topk_ids_host.data());
    weights_dev.ToDevice(weights_host.data());
    if(moe_buf_bytes > 0)
    {
        moe_buf_dev.ToDevice(moe_buf_host.data());
    }
    if(local_expert_masking)
        local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
    // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
    ck_tile::index_t workspace_size =
        moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
    ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
    if(workspace_size != 0 && clear_inside == false)
        moe_sorting_ws.SetZero(); // note, clear here!!!!
    moe_sorting_trait trait{
        index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
    moe_sorting_args karg
    {
        topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
        local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
        is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
        sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
        sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
        moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
        workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
        num_experts, topk,
#if MOE_SORTING_FMOE_2D_BUF
        moe_buf_interm_dim, moe_buf_elem_bytes
#else
        static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
    };
    ck_tile::stream_config sc{nullptr,
                              true,
                              /* log_level = */ (kname ? 1 : 0),
                              warmup,
                              repeat};
    // Launch; a negative return value means the configuration is unsupported.
    auto ms = moe_sorting(trait, karg, sc);
    printf("[%s|%s|%s|%d]tokens:%d",
           index_prec.c_str(),
           weight_prec.c_str(),
           workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
           dispatch_policy,
           tokens);
    if(is_local_token)
    {
        printf("(%d)", local_tokens);
    }
    printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
    if(local_expert_masking)
    {
        printf("local_eid:%s, ", args.get_str("local_eid").c_str());
    }
    if(moe_buf_bytes > 0)
    {
#if MOE_SORTING_FMOE_2D_BUF
        printf("moe_buf:%lu(%d,%d), ",
               static_cast<uint64_t>(moe_buf_bytes),
               moe_buf_interm_dim,
               moe_buf_elem_bytes);
#else
        printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
    }
    if(ms < 0)
        printf("not supported\n");
    else
        printf("ms:%f, ", ms);
    fflush(stdout);
    if(ms < 0)
    {
        return false;
    }
    // Copy all kernel outputs back for validation.
    sorted_ids_dev.FromDevice(sorted_ids_host.data());
    sorted_weights_dev.FromDevice(sorted_weights_host.data());
    sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
    sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
    if(moe_buf_bytes > 0)
    {
        moe_buf_dev.FromDevice(moe_buf_host.data());
    }
    bool rtn = true;
    if(validate)
    {
        // Host reference implementation of the sort.
        ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
        ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
        ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
        int32_t ref_total_tokens_post_pad = 0;
        ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
                                                              weights_host,
                                                              local_expert_masking_host,
                                                              sorted_ids_ref,
                                                              sorted_weights_ref,
                                                              sorted_expert_ids_ref,
                                                              ref_total_tokens_post_pad,
                                                              num_experts,
                                                              unit_size,
                                                              is_local_token ? local_tokens
                                                                             : tokens,
                                                              local_expert_masking);
        printf("total_tokens_post_pad:%d(%d), ",
               ref_total_tokens_post_pad,
               sorted_id_cnt_host.mData[0]);
        if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
        {
            // Only the first `ref_total_tokens_post_pad` entries are defined.
            size_t slen = ref_total_tokens_post_pad;
            rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
                                      sorted_ids_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect ids!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
                                      sorted_weights_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect w!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
                                      sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
                                      std::string("OUT Error: Incorrect eid!"),
                                      1e-6,
                                      1e-6);
            // if(is_local_token)
            {
                // The second counter dword must report the effective token count.
                auto t_ = is_local_token ? local_tokens : tokens;
                bool _f = t_ == sorted_id_cnt_host.mData[1];
                rtn &= _f;
                if(!_f)
                {
                    printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
                }
            }
        }
        else
        {
            printf("(token size not equal!!)");
            rtn = false;
        }
        if(moe_buf_bytes)
        {
            // The kernel is expected to zero the fmoe buffer; compare against
            // a freshly zero-initialized reference.
#if MOE_SORTING_FMOE_2D_BUF
            ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
            ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
            rtn &= ck_tile::check_err(
                moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
        }
        // rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
    }
    printf("valid:%s", rtn ? "y" : "n");
    fflush(stdout);
    if(!rtn)
        // Print the seed so a failing random case can be reproduced.
        printf(", (%d)", seed);
    printf("\n");
    fflush(stdout);
    return rtn;
}
// Parses the given CLI-style arguments and, when parsing succeeds, runs a
// single moe-sorting test case with them. Returns false on a parse failure
// or a failed test.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_case(int argc, char* argv[])
{
    auto parsed = create_args(argc, argv);
    if(!std::get<0>(parsed))
    {
        return false;
    }
    return test_moe_sorting<WeightType, IndexType>(std::get<1>(parsed));
}
// Runs every test case in `test_cases`, stopping at the first failure.
// Each case is a list of CLI-style argument strings forwarded to
// run_test_case(). Returns true only if all executed cases pass; a
// std::runtime_error from any case is reported and treated as failure.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    bool valid = true;
    for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
    {
        // Build an argv-style view over the stored strings. A dynamically
        // sized vector (instead of a fixed char*[7] guarded only by assert)
        // avoids a stack buffer overflow in release builds if a test case
        // ever grows more arguments.
        std::vector<char*> argv;
        argv.reserve(test_cases[test_idx].size());
        for(auto& arg : test_cases[test_idx])
        {
            argv.push_back(arg.data());
        }
        try
        {
            valid = valid && run_test_case<WeightType, IndexType>(
                                 static_cast<int>(argv.size()), argv.data());
            if(!valid)
                break;
        }
        catch(const std::runtime_error& e)
        {
            std::cerr << "Runtime error: " << e.what() << '\n';
            return false;
        }
    }
    return valid;
}
// Returns the table of CLI-style argument lists driving the fp32 tests.
// Two variants exist because the example's buffer options differ depending on
// whether MOE_SORTING_FMOE_2D_BUF is defined at build time. Cases cover small
// and large token/expert/topk combinations, local-token (EP) simulation via
// "-local_t", and local-expert masking via "-local_eid".
std::vector<std::vector<std::string>> create_test_cases()
{
#if MOE_SORTING_FMOE_2D_BUF
    return {{"-t=80", "-e=17", "-moe_buf_interm_dim=16", "-moe_buf_elem_bytes=4"},
            {"-t=111", "-e=117", "-moe_buf_interm_dim=4", "-moe_buf_elem_bytes=4"},
            {"-t=1000", "-e=55", "-moe_buf_interm_dim=1024", "-moe_buf_elem_bytes=1"},
            {"-t=99", "-e=120", "-moe_buf_interm_dim=10244", "-moe_buf_elem_bytes=2"},
            {"-t=175", "-e=64", "-k=8"},
            {"-t=65", "-e=8", "-k=2"},
            {"-t=1", "-e=25"},
            {"-t=31", "-e=19", "-k=15"},
            {"-t=81", "-e=37", "-k=7"},
            {"-t=23", "-e=1", "-k=1"},
            {"-t=127", "-e=99", "-k=19"},
            {"-t=71", "-e=11", "-k=11"},
            {"-t=1", "-e=1", "-k=1"},
            {"-t=99", "-e=2", "-k=1"},
            {"-t=333", "-e=99", "-k=13"},
            {"-t=11", "-e=256", "-k=5"},
            {"-t=64", "-e=455", "-k=8"},
            {"-t=777", "-e=802", "-k=99"},
            {"-t=4097", "-e=906", "-k=51"},
            {"-t=128", "-e=32", "-k=5", "-local_t=6", "-moe_buf_interm_dim=262144"},
            {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
            {"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
            {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
            {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
            {"-t=128", "-e=128", "-k=6", "-moe_buf_interm_dim=163840", "-moe_buf_elem_bytes=1"},
            {"-t=8192", "-e=32", "-k=5", "-local_t=11", "-moe_buf_interm_dim=163840"},
            {"-t=8192",
             "-e=32",
             "-k=8",
             "-local_t=12",
             "-moe_buf_interm_dim=163840",
             "-moe_buf_elem_bytes=1"},
            {"-t=8192", "-e=256", "-k=5", "-local_t=13", "-moe_buf_interm_dim=163840"},
            {"-t=8192", "-e=256", "-k=8", "-local_t=8", "-moe_buf_interm_dim=163840"},
            {"-t=163840",
             "-e=256",
             "-k=8",
             "-local_t=4",
             "-moe_buf_interm_dim=163840",
             "-moe_buf_elem_bytes=4"},
            {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
            {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
            {"-t=99", "-local_t=93", "-e=121", "-local_t=4", "-moe_buf_interm_dim=10244"},
            {"-t=536", "-local_t=345", "-e=802", "-k=99"},
            {"-t=331", "-local_t=39", "-e=83", "-k=33"},
            {"-t=765", "-local_t=654", "-e=783", "-k=8"},
            {"-t=23", "-local_t=9", "-e=1", "-k=1"},
            {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
            {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
            {"-t=133940",
             "-local_t=111921",
             "-e=256",
             "-k=17",
             "-local_t=2",
             "-moe_buf_interm_dim=133940",
             "-moe_buf_elem_bytes=1"}};
#else
    return {{"-t=80", "-e=17", "-moe_buf_size=16"},
            {"-t=111", "-e=117", "-moe_buf_size=4"},
            {"-t=1000", "-e=55", "-moe_buf_size=1024"},
            {"-t=99", "-e=120", "-moe_buf_size=10244"},
            {"-t=175", "-e=64", "-k=8"},
            {"-t=65", "-e=8", "-k=2"},
            {"-t=1", "-e=25"},
            {"-t=31", "-e=19", "-k=15"},
            {"-t=81", "-e=37", "-k=7"},
            {"-t=23", "-e=1", "-k=1"},
            {"-t=127", "-e=99", "-k=19"},
            {"-t=71", "-e=11", "-k=11"},
            {"-t=1", "-e=1", "-k=1"},
            {"-t=99", "-e=2", "-k=1"},
            {"-t=333", "-e=99", "-k=13"},
            {"-t=11", "-e=256", "-k=5"},
            {"-t=64", "-e=455", "-k=8"},
            {"-t=777", "-e=802", "-k=99"},
            {"-t=4097", "-e=906", "-k=51"},
            {"-t=128", "-e=32", "-k=5", "-moe_buf_size=262144"},
            {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
            {"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
            {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
            {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
            {"-t=128", "-e=128", "-k=6", "-moe_buf_size=163840"},
            {"-t=8192", "-e=32", "-k=5", "-moe_buf_size=163840"},
            {"-t=8192", "-e=32", "-k=8", "-moe_buf_size=163840"},
            {"-t=8192", "-e=256", "-k=5", "-moe_buf_size=163840"},
            {"-t=8192", "-e=256", "-k=8", "-moe_buf_size=163840"},
            {"-t=163840", "-e=256", "-k=8", "-moe_buf_size=163840"},
            {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
            {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
            {"-t=99", "-local_t=93", "-e=121", "-moe_buf_size=10244"},
            {"-t=536", "-local_t=345", "-e=802", "-k=99"},
            {"-t=331", "-local_t=39", "-e=83", "-k=33"},
            {"-t=765", "-local_t=654", "-e=783", "-k=8"},
            {"-t=23", "-local_t=9", "-e=1", "-k=1"},
            {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
            {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
            {"-t=133940", "-local_t=111921", "-e=256", "-k=17", "-moe_buf_size=133940"}};
#endif
}
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases();
return !run_test_cases<float, ck_tile::index_t>(test_cases);
}

View File

@@ -0,0 +1,33 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
    # add_permute_test(<target> <main-src>)
    # Creates one permute test executable. When PERMUTE_USE_ALTERNATIVE_IMPL is
    # enabled (the default), the alternative matrix-core swizzle implementation
    # is compiled in and the matching preprocessor define is set.
    function(add_permute_test TARGET_NAME MAIN_SRC)
        add_test_executable(${TARGET_NAME} ${MAIN_SRC})
        if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
            set(PERMUTE_USE_ALTERNATIVE_IMPL true)
        endif()
        if(PERMUTE_USE_ALTERNATIVE_IMPL)
            target_compile_definitions(${TARGET_NAME} PRIVATE PERMUTE_USE_ALTERNATIVE_IMPL)
            target_sources(${TARGET_NAME} PRIVATE alternative_impl/matrix_core_swizzle.cpp)
        endif()
    endfunction()

    # Umbrella target that builds every permute test at once.
    set(CUSTOM_TARGET_NAME test_ck_tile_permute)
    add_custom_target(${CUSTOM_TARGET_NAME})

    # One test executable per precision.
    foreach(prec IN ITEMS fp16 fp8 fp32)
        add_permute_test(test_ck_tile_permute_${prec} permute_${prec}.cpp)
        add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_${prec})
    endforeach()
else()
    message(DEBUG "Skipping ck_tile_permute tests for current target")
endif()

View File

@@ -0,0 +1,101 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}

View File

@@ -0,0 +1,20 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "matrix_core_swizzle_kernel.hpp"
#include <string>
// Selects which matrix_core_swizzle kernel instance the host API dispatches to.
struct matrix_core_swizzle_traits
{
    std::string data_type; // fp16 only
    std::string inst;      // 32x32x8, 16x16x16
    std::string permute;   // permute order string, e.g. "0,1,4,2,5,3,6"
};
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);

View File

@@ -0,0 +1,413 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
// if set to 1, slightly more instructions generated to calculate address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
// MFMA instruction whose lane mapping drives the swizzle layout.
enum class matrix_core_inst_enum
{
    MFMA_32x32x8_F16  = 0,
    MFMA_16x16x16_F16 = 1,
};
namespace detail {
// Maps a matrix_core_inst_enum value to the matching ck_tile warp-gemm type.
template <matrix_core_inst_enum>
struct to_warp_gemm;

template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
{
    using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
};
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
{
    using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
};
} // namespace detail

// Convenience alias for detail::to_warp_gemm<Inst>::type.
template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
// Output layouts supported by the swizzle kernel, named after the permutation
// applied to the factored input dimensions (the comment on each enumerator is
// the permute-order string accepted by the host API).
// TODO: in the permute patterns below, the last 3 dims are within a wave
enum class matrix_core_permute_style
{
    permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
    b_nr_kr_kw_nw_kv            = 2, // 0,1,3,4,2,5
    // alias: same layout viewed as one flattened per-wave chunk
    b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
//
// 4(waves) 32(mfma_m lane)
// | |
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
// nr kr |
// nr 4 32 kr 2 8 2(klane)
//
// permute: 0,1,4,2,5,3,6
// or
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
// permute: 0,1,2,4,5,3,6
//
// this kernel only deals with fp16/bf16 data (16-bit), and uses a 2d block size to do the swizzling
// for simplicity, only consider the case where n/k are multiples of the block size
// independent host arg with no template
struct matrix_core_swizzle_host_args
{
    const void* p_src; // input, laid out as [batch, n, k]
    void* p_dst;       // output, same size, permuted layout
    int32_t batch;
    int32_t n;
    int32_t k;
};
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
// Kernel wrapper: each workgroup loads one NPerBlock x KPerBlock fp16 tile of
// a [batch, n, k] tensor and stores it back in the matrix-core friendly
// layout selected by pstyle_, using the lane mapping of instruction Inst_.
template <int BLOCK_SIZE_ = 256,
          int NPerBlock_ = 256,
          int KPerBlock_ = 128,
          matrix_core_permute_style pstyle_ =
              matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
          matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
struct matrix_core_swizzle_kernel
{
    // host and kernel share the same plain argument struct
    using karg = matrix_core_swizzle_host_args;
    using harg = matrix_core_swizzle_host_args;
    static constexpr int BLOCK_SIZE      = BLOCK_SIZE_;
    static constexpr int WavesPerBlock_N = 4;
    static constexpr int WavesPerBlock_K = 1;
    // 4x1 waves of 64 lanes must exactly fill the workgroup
    static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
    static constexpr int NPerBlock                    = NPerBlock_;
    static constexpr int KPerBlock                    = KPerBlock_;
    static constexpr matrix_core_permute_style pstyle = pstyle_;
    static constexpr matrix_core_inst_enum Inst       = Inst_;
    // 8 x fp16 per thread = dwordx4 access (see file-level layout comment)
    static constexpr ck_tile::index_t Alignment = 8;
    karg a;
    dim3 grids; // (k-tiles, n-tiles, batch), derived in the host ctor
    using WarpGemm = to_warp_gemm_t<Inst>;
    // Host-side ctor: stash the args and size the launch grid by tile count.
    __host__ matrix_core_swizzle_kernel(harg h)
    {
        a                   = h;
        ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
        ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
        grids               = dim3(ks, ns, h.batch);
    }
    // only full tiles are handled (file-level comment: n/k multiple of block)
    __host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
    __host__ void operator()(const ck_tile::stream_config& s) const
    {
        ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
    }
    struct kernel
    {
        // Distribution for reading the source tile, factored as
        // (N0,N1,N2,K0,K1,K2). NOTE(review): N1 appears to map to the wave id
        // and (K1,N2) to the lane id, with (N0,K0,K2) per-thread (K2 being the
        // Alignment-wide vector dim) — confirm against tile_distribution docs.
        __device__ static constexpr auto get_src_dist()
        {
            using namespace ck_tile;
            constexpr index_t K2 = Alignment;
            constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
            static_assert(NPerBlock % (N1 * N2) == 0);
            static_assert(KPerBlock % (K1 * K2) == 0);
            constexpr index_t K0 = KPerBlock / (K1 * K2);
            constexpr index_t N0 = NPerBlock / (N1 * N2);
            // clang-format off
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,// 0
                    //       1             2             3             4             5             6
                    tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
                    //    N1           K1 N2
                    tuple<sequence<2>, sequence<5, 3>>,
                    tuple<sequence<0>, sequence<0, 0>>,
                    //       N0 K0 K2
                    sequence<1, 4, 6>,
                    sequence<0, 0, 0>>{});
            // clang-format on
        }
        // Distribution for writing the destination tile; the dimension order
        // (and hence the encoding) depends on the selected permute style.
        __device__ static constexpr auto get_dst_dist()
        {
            using namespace ck_tile;
            constexpr index_t K2 = Alignment;
            constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
            static_assert(NPerBlock % (N1 * N2) == 0);
            static_assert(KPerBlock % (K1 * K2) == 0);
            constexpr index_t K0 = KPerBlock / (K1 * K2);
            constexpr index_t N0 = NPerBlock / (N1 * N2);
            if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
            {
                // clang-format off
                return make_static_tile_distribution(
                    tile_distribution_encoding<
                        sequence<1>,// 0
                        //       1             2             3             4             5             6
                        tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
                        //    N1           K1 N2
                        tuple<sequence<3>, sequence<4, 5>>,
                        tuple<sequence<0>, sequence<0, 0>>,
                        //       N0 K0 K2
                        sequence<1, 2, 6>,
                        sequence<0, 0, 0>>{});
                // clang-format on
            }
            else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
            {
                // clang-format off
                return make_static_tile_distribution(
                    tile_distribution_encoding<
                        sequence<1>,// 0
                        //       1             2             3             4             5             6
                        tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
                        //    N1           K1 N2
                        tuple<sequence<2>, sequence<4, 5>>,
                        tuple<sequence<0>, sequence<0, 0>>,
                        //       N0 K0 K2
                        sequence<1, 3, 6>,
                        sequence<0, 0, 0>>{});
                // clang-format on
            }
            else
            {
                // clang-format off
                // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
                constexpr index_t Kv = Alignment;
                constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                // NOTE(review): K1 == Kw and K2 == Kv here, so this re-assert
                // is equivalent to KPerBlock % (Kw * Kv) == 0
                static_assert(KPerBlock % (K1 * K2) == 0);
                constexpr index_t Nr = NPerBlock / Nw;
                constexpr index_t Kr = KPerBlock / (Kv * Kw);
                constexpr index_t Nr_p = WavesPerBlock_N;
                constexpr index_t Kr_p = WavesPerBlock_K;
                constexpr index_t Nr_y = Nr / Nr_p;
                constexpr index_t Kr_y = Kr / Kr_p;
                return make_static_tile_distribution(
#if MERGE_2D_013425
                    tile_distribution_encoding<
                        sequence<1>,// 0 R
                        // major       1                          2
                        // minor       0     1    2               0     1    2   3
                        tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>,    // H
                        //    Nr_p, Kr_p       Kw Nw
                        tuple<sequence<1 , 2>, sequence<2, 1>>,     // p major
                        tuple<sequence<1 , 1>, sequence<2, 2>>,     // p minor
                        //       Nr_y Kr_y Kv
                        sequence<1, 2, 2>,      // Y major
                        sequence<0, 0, 3>>{});  // y minor
#else
                    tile_distribution_encoding<
                        sequence<1>,// 0 R
                        // major       1                 2                 3
                        // minor       0     1           0     1           0   1   2
                        tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,  // H
                        //    Nr_p, Kr_p       Kw Nw
                        tuple<sequence<1 , 2>, sequence<3, 3>>,     // p major
                        tuple<sequence<1 , 1>, sequence<0, 1>>,     // p minor
                        //       Nr_y Kr_y Kv
                        sequence<1, 2, 3>,      // Y major
                        sequence<0, 0, 2>>{});  // y minor
#endif
                // clang-format on
            }
        }
        // Device body: load one source tile, re-tag it with the destination
        // distribution, and store it into the permuted destination view.
        __device__ void operator()(karg a_)
        {
            using namespace ck_tile;
            // tile coordinates of this workgroup: x=k-tile, y=n-tile, z=batch
            index_t i_k = blockIdx.x;
            index_t i_n = blockIdx.y;
            index_t i_b = blockIdx.z;
            // per-tensor factorization of the full n/k extents
            constexpr index_t k2 = Alignment;
            constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
            const index_t k0 = a_.k / (k1 * k2);
            const index_t n0 = a_.n / (n1 * n2);
            // same factorization restricted to one block tile
            constexpr index_t k2_tile = Alignment;
            constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
            constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
            constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
            // advance both pointers to the start of this batch slice
            const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
            fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
            const auto src_view = [&]() {
                const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                    p_src,
                    make_tuple(n0, n1, n2, k0, k1, k2),
                    number<Alignment>{}); // control vector load
                return tmp;
            }();
            const auto src_window = make_tile_window(src_view,
                                                     make_tuple(number<n0_tile>{},
                                                                number<n1_tile>{},
                                                                number<n2_tile>{},
                                                                number<k0_tile>{},
                                                                number<k1_tile>{},
                                                                number<k2_tile>{}),
                                                     {i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
                                                     get_src_dist());
            // destination view: dimension order follows the permute style
            auto dst_view = [&]() {
                if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
                {
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(n0, k0, n1, k1, n2, k2),
                        number<Alignment>{}); // control vector load
                    return tmp;
                }
                else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
                {
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(n0, n1, k0, k1, n2, k2),
                        number<Alignment>{}); // control vector load
                    return tmp;
                }
                else
                {
#if MERGE_2D_013425
                    constexpr index_t kv = Alignment;
                    constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    // constexpr index_t waveflatten = kw*nw*kv;
                    const index_t kr = a_.k / (k1 * k2);
                    const index_t nr = a_.n / nw;
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
                        number<Alignment>{}); // control vector load
                    // merge (nr,nw) and (kr,kw,kv) back into a 2D n/k view
                    auto tmp_1 = transform_tensor_view(
                        tmp,
                        make_tuple(
                            make_merge_transform(make_tuple(nr, number<nw>{})),
                            make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
                        make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                    return tmp_1;
#else
                    // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
                    constexpr index_t kv = Alignment;
                    constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    constexpr index_t waveflatten = kw * nw * kv;
                    const index_t kr = a_.k / (k1 * k2);
                    const index_t nr = a_.n / nw;
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(nr, kr, waveflatten),
                        number<Alignment>{}); // control vector load
                    return tmp;
#endif
                }
            }();
            auto dst_window = [&]() {
                if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
                {
                    return make_tile_window(dst_view,
                                            make_tuple(number<n0_tile>{},
                                                       number<k0_tile>{},
                                                       number<n1_tile>{},
                                                       number<k1_tile>{},
                                                       number<n2_tile>{},
                                                       number<k2_tile>{}),
                                            {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
                                            get_dst_dist());
                }
                else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
                {
                    return make_tile_window(dst_view,
                                            make_tuple(number<n0_tile>{},
                                                       number<n1_tile>{},
                                                       number<k0_tile>{},
                                                       number<k1_tile>{},
                                                       number<n2_tile>{},
                                                       number<k2_tile>{}),
                                            {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
                                            get_dst_dist());
                }
                else
                {
#if MERGE_2D_013425
                    // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
                    return make_tile_window(dst_view,
                                            make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                            {i_n * NPerBlock, i_k * KPerBlock},
                                            get_dst_dist());
#else
                    // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
                    constexpr index_t kv = Alignment;
                    constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    constexpr index_t waveflatten_tile = kw * nw * kv;
                    constexpr index_t nr_tile = NPerBlock / nw;
                    constexpr index_t kr_tile = KPerBlock / (kw * kv);
                    return make_tile_window(dst_view,
                                            make_tuple(number<nr_tile>{},
                                                       number<kr_tile>{},
                                                       number<waveflatten_tile>{}),
                                            {i_n * nr_tile, i_k * kr_tile, 0},
                                            get_dst_dist());
#endif
                }
            }();
            // actual load store
            auto src_tile = load_tile(src_window);
            // now we only swap the distribution from src to dst, no extra movement occurs
            auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
            dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
            // final store
            store_tile(dst_window, dst_tile);
        }
    };
};

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/permute.hpp"
#include <string>
// Selects the data-type instantiation of the generic permute kernel
// (see permute() in permute_utils.inc): "fp8", "fp16" or "fp32".
struct permute_traits
{
    std::string data_type;
};
using permute_args = ck_tile::GenericPermuteHostArgs;
// host API
float permute(permute_traits, permute_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,29 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases_fp16();
return !run_test_cases<ck_tile::half_t>(test_cases);
}

View File

@@ -0,0 +1,29 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp32");
return !run_test_cases<float>(test_cases);
}

View File

@@ -0,0 +1,29 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp8");
return !run_test_cases<ck_tile::fp8_t>(test_cases);
}

View File

@@ -0,0 +1,490 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace detail {
// Maps a byte width (1/2/4) to a same-sized signed integer type; used by
// run() to compare results bit-exactly regardless of the element type.
template <int bytes>
struct to_integer_type;
template <>
struct to_integer_type<4>
{
    using type = int32_t;
};
template <>
struct to_integer_type<2>
{
    using type = int16_t;
};
template <>
struct to_integer_type<1>
{
    using type = int8_t;
};
} // namespace detail
// Convenience alias for detail::to_integer_type<bytes>::type.
template <int bytes>
using to_integer_type = typename detail::to_integer_type<bytes>::type;
// host API (shoule come from codegen)
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp8") == 0)
{
using DataType = ck_tile::fp8_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp16") == 0)
{
using DataType = ck_tile::half_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp32") == 0)
{
using DataType = float;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
// Stream a vector as "[a, b, c]" (empty vector prints "[]").
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    const char* separator = "";
    for(const T& elem : v)
    {
        os << separator << elem;
        separator = ", ";
    }
    return os << "]";
}
// Build the argument parser for one permute test case and parse argv.
// start_index is the first argv element handed to the parser (default 0:
// every element is treated as an option — presumably matching how
// run_test_cases builds argv without a program name; confirm with ArgParser).
// Returns (parse-succeeded, parser).
auto create_args(int argc, char* argv[], int start_index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
        .insert("shape", "2,3,4", "the shape of the input tensor")
        .insert("perm", "2,1,0", "permute perm")
        .insert("kname", "0", "set to 1 will print kernel name")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel");
    bool result = arg_parser.parse(argc, argv, start_index);
    return std::make_tuple(result, arg_parser);
}
// different threshold for different dtype
// Default (rtol, atol) tolerances used when comparing against the reference.
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}
// bf16 has fewer mantissa bits, so tolerances are looser than the default.
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}
// fp8 comparison uses (max rounding-point distance, atol) instead of
// (rtol, atol): integer-valued inits ("ui"/"ni") must match exactly, any
// other init allows one representable-value step of slack.
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
        unsigned max_rounding_point_distance = 0;
        double atol                          = 2e-3;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
    else
    {
        unsigned max_rounding_point_distance = 1;
        double atol                          = 0.0625;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
// Run one permute test case described by the parsed args: build input/output
// host tensors, execute the permute on device (preferring the matrix-core
// swizzle fast path when the shape/perm match a supported pattern), and
// optionally validate the device result bit-exactly against the host
// reference. Returns true on pass (or when validation is disabled).
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
    int do_validation     = arg_parser.get_int("v");
    auto shape            = decode_vec(arg_parser.get_str("shape"));
    auto perm             = decode_vec(arg_parser.get_str("perm"));
    int stream_warmup     = arg_parser.get_int("warmup");
    int stream_repeat     = arg_parser.get_int("repeat");
    bool kname            = arg_parser.get_bool("kname");
    int seed              = arg_parser.get_int("seed");
    assert(shape.size() == perm.size());
    ck_tile::index_t rank = perm.size();
    // the generic permute kernel supports only up to kMaxRanks dimensions
    if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
    {
        printf("rank %d permute is not support yet\n", rank);
        return false;
    }
    ck_tile::HostTensor<DataType> x(shape);
    // integer-valued init so the fp result of the permute is bit-exact
    ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
    // output shape = input shape permuted by perm
    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        // std::cout << "@@@@" << tmp << std::endl;
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            // std::cout << "  i:" << i << ", perm:" << perm[i] << ", rak:" <<
            // static_cast<int>(rank)
            //           << std::endl;
            tmp[i] = shape[perm[i]];
        }
        // std::cout << "@@@" << tmp << std::endl;
        return tmp;
    }();
    ck_tile::HostTensor<DataType> y(y_shape);
    ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
    x_buf.ToDevice(x.data());
    std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
              << std::endl;
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (kname ? 1 : 0),
                                         stream_warmup,
                                         stream_repeat};
    float ave_time = 0.f;
    // fallback path: the generic (any-rank) permute kernel
    auto run_permute = [&]() {
        permute_traits t;
        t.data_type = data_type;
        permute_args a;
        a.p_src = x_buf.GetDeviceBuffer();
        a.p_dst = y_buf.GetDeviceBuffer();
        a.rank  = rank;
        std::copy(shape.begin(), shape.end(), a.shape);
        std::copy(perm.begin(), perm.end(), a.perm);
        return permute(t, a, stream_config);
    };
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
    // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
    // Fast path: these three permute strings map to the matrix-core swizzle
    // kernel when the factored shape matches an mfma lane layout; otherwise
    // fall back to the generic kernel.
    if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
    {
        if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
        {
            // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute   = arg_parser.get_str("perm");
            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];
            // shape factors: (batch, nr, nw, kr, kw, kv)
            auto nr = shape[1];
            auto nw = shape[2];
            auto kr = shape[3];
            auto kw = shape[4];
            auto kv = shape[5];
            a.n     = nr * nw;
            a.k     = kr * kw * kv;
            // lane-layout constants below select the mfma instruction;
            // NOTE(review): the nr/kr divisibility checks presumably mirror
            // the kernel's per-block tiling requirements — confirm
            if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                ave_time = run_permute();
            }
        }
        else
        {
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute   = arg_parser.get_str("perm");
            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];
            // shape factors: (batch, n0, n1, n2, k0, k1, k2)
            a.n = shape[1] * shape[2] * shape[3];
            a.k = shape[4] * shape[5] * shape[6];
            if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
               shape[4] % 8 == 0 && shape[1] % 2 == 0)
            {
                // 32x32x8 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
                // shape  =*,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
                    shape[4] % 4 == 0 && shape[1] % 4 == 0)
            {
                // 16x16x16 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,4x,4x,4,4,16,8
                // shape  =*,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                ave_time = run_permute();
            }
        }
    }
    else
#endif
    {
        ave_time = run_permute();
    }
    std::cout << ", time:" << ave_time << "ms" << std::flush;
    bool pass = true;
    if(do_validation)
    {
        // host reference, then bitwise comparison via a same-width integer
        // type (integer-valued inputs make the permute result bit-exact)
        reference_permute(x, y, perm);
        ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
        y_buf.FromDevice(y_dev.data());
        pass = std::equal(
            y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
                using itype = to_integer_type<sizeof(DataType)>;
                itype i_d   = ck_tile::bit_cast<itype>(d);
                itype i_h   = ck_tile::bit_cast<itype>(h);
                return i_d == i_h;
            });
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
    }
    std::cout << std::endl;
    return pass;
}
// Parse one argv-style test case and, if parsing succeeds, execute it.
template <typename DataType>
bool run_test_case(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    return parsed_ok ? run<DataType>(arg_parser) : false;
}
// Execute every test case (each one a fixed-size argv of 6 option strings),
// stopping at the first failure. Returns true iff all cases pass.
template <typename DataType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    constexpr int num_args = 6;
    char* argv[num_args];
    bool all_passed = true;
    for(auto& test_case : test_cases)
    {
        assert(test_case.size() == num_args && "invalid number of arguments in test case");
        for(int i = 0; i < num_args; ++i)
        {
            argv[i] = test_case[i].data();
        }
        all_passed = all_passed && run_test_case<DataType>(num_args, argv);
        if(!all_passed)
        {
            break;
        }
    }
    return all_passed;
}
// Generic permute test cases shared by all precisions. Each entry is one
// argv-style argument list of exactly 6 options (see run_test_cases).
std::vector<std::vector<std::string>> create_test_cases(const std::string prec)
{
    auto make_case = [&prec](const std::string& shape, const std::string& perm) {
        return std::vector<std::string>{
            "-prec=" + prec, "-shape=" + shape, "-perm=" + perm, "-v=1", "-warmup=0", "-repeat=1"};
    };
    return {make_case("3,8", "1,0"),
            make_case("48,6,8", "2,1,0"),
            make_case("24,128,3", "0,2,1"),
            make_case("4,10,7,6", "0,2,3,1"),
            make_case("8,24,36,10", "3,1,2,0"),
            make_case("8,1,36,4", "2,1,0,3"),
            make_case("5,10,16,2,36,4", "4,5,2,1,0,3"),
            make_case("2,32,8,3,6,2,5,4", "5,2,4,7,1,6,3,0")};
}
// fp16 permute test cases: the first eight exercise the matrix-core swizzle
// fast-path permute patterns, the rest are the shared generic cases.
std::vector<std::vector<std::string>> create_test_cases_fp16()
{
    auto make_case = [](const std::string& shape, const std::string& perm) {
        return std::vector<std::string>{
            "-prec=fp16", "-shape=" + shape, "-perm=" + perm, "-v=1", "-warmup=0", "-repeat=1"};
    };
    return {// b_n0_k0_n1_k1_n2_k2 pattern
            make_case("3,6,4,32,16,2,8", "0,1,4,2,5,3,6"),
            make_case("5,10,4,32,8,2,8", "0,1,4,2,5,3,6"),
            make_case("3,8,4,16,16,4,8", "0,1,4,2,5,3,6"),
            // b_n0_n1_k0_k1_n2_k2 pattern
            make_case("3,6,4,32,16,2,8", "0,1,2,4,5,3,6"),
            make_case("5,10,4,32,8,2,8", "0,1,2,4,5,3,6"),
            make_case("3,8,4,16,16,4,8", "0,1,2,4,5,3,6"),
            // waveflatten pattern
            make_case("2,8,16,8,4,8", "0,1,3,4,2,5"),
            make_case("1,24,32,16,2,8", "0,1,3,4,2,5"),
            // generic permute cases
            make_case("3,8", "1,0"),
            make_case("48,6,8", "2,1,0"),
            make_case("24,128,3", "0,2,1"),
            make_case("4,10,7,6", "0,2,3,1"),
            make_case("8,24,36,10", "3,1,2,0"),
            make_case("8,1,36,4", "2,1,0,3"),
            make_case("5,10,16,2,36,4", "4,5,2,1,0,3"),
            make_case("2,32,8,3,6,2,5,4", "5,2,4,7,1,6,3,0")};
}