[CK_Tile] Refactor MOE Sorting and Smoothquant ctests to gtests (#2596)

* refactor moe_sorting ctests to use gtest framework

* Refactor ctests for smoothquant to gtests

* fix clang format to use version 18

* Print local_eid in MOE sorting gtests

* Remove extra space in smoothquant output
This commit is contained in:
Emily Martins
2025-08-14 11:54:57 -06:00
committed by GitHub
parent 7f14772406
commit 70dce4e0c6
16 changed files with 2025 additions and 864 deletions

View File

@@ -1,14 +1,19 @@
# Currently ck_tile is only built on gfx90a, gfx942 and gfx950
if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a")
add_test_executable(test_ck_tile_moe_sorting_fp32 moe_sorting_fp32.cpp moe_sorting_api.cpp)
target_include_directories(test_ck_tile_moe_sorting_fp32 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
function(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(${EXECUTABLE} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(test_ck_tile_moe_sorting_fp32 PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DMOE_SORTING_FMOE_2D_BUF=${USE_2D_BUF})
target_compile_options(${EXECUTABLE} PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
endfunction(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
add_moe_sorting_test(test_ck_tile_moe_sorting_2d_buf 1)
add_moe_sorting_test(test_ck_tile_moe_sorting 0)
else()
message(DEBUG "Skipping ck_tile_moe_sorting tests for current target")

View File

@@ -1,544 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
auto create_args(int argc, char* argv[], int index = 0)
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
.insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
.insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
.insert("t",
"128",
"number of input tokens.\n"
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
.insert(
"local_t",
"-1",
"Number of local input tokens for curent rank.\n"
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
"This feature is to simulate EP case where where each rank has different tokens.\n"
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
.insert("e", "8", "number of num_experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
#if MOE_SORTING_FMOE_2D_BUF
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
.insert(
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
#else
.insert("moe_buf_size", "0", "moe_buf_size")
#endif
.insert("ci",
"1",
"clear workspace inside API or not(if \"0\", require manually clear outside)")
.insert(
"dispatch",
"0",
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed",
"-1",
"seed to be used. When set to -1, a random seed will be generated each time "
"invoking this example")
.insert("kname", "0", "prints the kernel name when set to 1")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv, index);
return std::make_tuple(result, arg_parser);
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int local_tokens = args.get_int("local_t");
int num_experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
#if MOE_SORTING_FMOE_2D_BUF
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
#else
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
#endif
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
bool clear_inside = args.get_int("ci") != 0;
int dispatch_policy = args.get_int("dispatch");
int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > num_experts)
{
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
topk,
num_experts);
return false;
}
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
// case
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
if(local_tokens > tokens)
{
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
return false;
}
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
auto local_eid = args.get_int_vec("local_eid");
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
// for simplicity, below buffer allocate 2 dword
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_host(
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
moe_buf_elem_bytes});
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#else
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#endif
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_dev(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
// used for simulating dynamic_tokens for EP case
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_bytes > 0)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size =
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0 && clear_inside == false)
moe_sorting_ws.SetZero(); // note, clear here!!!!
moe_sorting_trait trait{
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
topk,
#if MOE_SORTING_FMOE_2D_BUF
moe_buf_interm_dim,
moe_buf_elem_bytes
#else
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s|%s|%d]tokens:%d",
index_prec.c_str(),
weight_prec.c_str(),
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
dispatch_policy,
tokens);
if(is_local_token)
{
printf("(%d)", local_tokens);
}
printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(moe_buf_bytes > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
printf("moe_buf:%lu(%d,%d), ",
static_cast<uint64_t>(moe_buf_bytes),
moe_buf_interm_dim,
moe_buf_elem_bytes);
#else
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
}
if(ms < 0)
printf("not supported\n");
else
printf("ms:%f, ", ms);
fflush(stdout);
if(ms < 0)
{
return false;
}
sorted_ids_dev.FromDevice(sorted_ids_host.data());
sorted_weights_dev.FromDevice(sorted_weights_host.data());
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
if(moe_buf_bytes > 0)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
local_expert_masking_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
unit_size,
is_local_token ? local_tokens
: tokens,
local_expert_masking);
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
{
size_t slen = ref_total_tokens_post_pad;
rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
sorted_ids_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect ids!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
sorted_weights_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
// if(is_local_token)
{
auto t_ = is_local_token ? local_tokens : tokens;
bool _f = t_ == sorted_id_cnt_host.mData[1];
rtn &= _f;
if(!_f)
{
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
}
}
}
else
{
printf("(token size not equal!!)");
rtn = false;
}
if(moe_buf_bytes)
{
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
// rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
printf("valid:%s", rtn ? "y" : "n");
fflush(stdout);
if(!rtn)
printf(", (%d)", seed);
printf("\n");
fflush(stdout);
return rtn;
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_case(int argc, char* argv[])
{
auto [result, args] = create_args(argc, argv);
if(!result)
return false;
return test_moe_sorting<WeightType, IndexType>(args);
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
bool valid = true;
for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
{
constexpr int max_num_args = 7;
const int num_args = test_cases[test_idx].size();
assert(max_num_args >= num_args && "Invalid number of arguments in test case");
char* argv[max_num_args];
for(int arg_idx = 0; arg_idx < num_args; ++arg_idx)
{
argv[arg_idx] = test_cases[test_idx][arg_idx].data();
}
try
{
valid = valid && run_test_case<WeightType, IndexType>(num_args, argv);
if(!valid)
break;
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return false;
}
}
return valid;
}
std::vector<std::vector<std::string>> create_test_cases()
{
#if MOE_SORTING_FMOE_2D_BUF
return {{"-t=80", "-e=17", "-moe_buf_interm_dim=16", "-moe_buf_elem_bytes=4"},
{"-t=111", "-e=117", "-moe_buf_interm_dim=4", "-moe_buf_elem_bytes=4"},
{"-t=1000", "-e=55", "-moe_buf_interm_dim=1024", "-moe_buf_elem_bytes=1"},
{"-t=99", "-e=120", "-moe_buf_interm_dim=10244", "-moe_buf_elem_bytes=2"},
{"-t=175", "-e=64", "-k=8"},
{"-t=65", "-e=8", "-k=2"},
{"-t=1", "-e=25"},
{"-t=31", "-e=19", "-k=15"},
{"-t=81", "-e=37", "-k=7"},
{"-t=23", "-e=1", "-k=1"},
{"-t=127", "-e=99", "-k=19"},
{"-t=71", "-e=11", "-k=11"},
{"-t=1", "-e=1", "-k=1"},
{"-t=99", "-e=2", "-k=1"},
{"-t=333", "-e=99", "-k=13"},
{"-t=11", "-e=256", "-k=5"},
{"-t=64", "-e=455", "-k=8"},
{"-t=777", "-e=802", "-k=99"},
{"-t=4097", "-e=906", "-k=51"},
{"-t=128", "-e=32", "-k=5", "-local_t=6", "-moe_buf_interm_dim=262144"},
{"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
{"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
{"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
{"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
{"-t=128", "-e=128", "-k=6", "-moe_buf_interm_dim=163840", "-moe_buf_elem_bytes=1"},
{"-t=8192", "-e=32", "-k=5", "-local_t=11", "-moe_buf_interm_dim=163840"},
{"-t=8192",
"-e=32",
"-k=8",
"-local_t=12",
"-moe_buf_interm_dim=163840",
"-moe_buf_elem_bytes=1"},
{"-t=8192", "-e=256", "-k=5", "-local_t=13", "-moe_buf_interm_dim=163840"},
{"-t=8192", "-e=256", "-k=8", "-local_t=8", "-moe_buf_interm_dim=163840"},
{"-t=163840",
"-e=256",
"-k=8",
"-local_t=4",
"-moe_buf_interm_dim=163840",
"-moe_buf_elem_bytes=4"},
{"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
{"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
{"-t=99", "-local_t=93", "-e=121", "-local_t=4", "-moe_buf_interm_dim=10244"},
{"-t=536", "-local_t=345", "-e=802", "-k=99"},
{"-t=331", "-local_t=39", "-e=83", "-k=33"},
{"-t=765", "-local_t=654", "-e=783", "-k=8"},
{"-t=23", "-local_t=9", "-e=1", "-k=1"},
{"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
{"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
{"-t=133940",
"-local_t=111921",
"-e=256",
"-k=17",
"-local_t=2",
"-moe_buf_interm_dim=133940",
"-moe_buf_elem_bytes=1"}};
#else
return {{"-t=80", "-e=17", "-moe_buf_size=16"},
{"-t=111", "-e=117", "-moe_buf_size=4"},
{"-t=1000", "-e=55", "-moe_buf_size=1024"},
{"-t=99", "-e=120", "-moe_buf_size=10244"},
{"-t=175", "-e=64", "-k=8"},
{"-t=65", "-e=8", "-k=2"},
{"-t=1", "-e=25"},
{"-t=31", "-e=19", "-k=15"},
{"-t=81", "-e=37", "-k=7"},
{"-t=23", "-e=1", "-k=1"},
{"-t=127", "-e=99", "-k=19"},
{"-t=71", "-e=11", "-k=11"},
{"-t=1", "-e=1", "-k=1"},
{"-t=99", "-e=2", "-k=1"},
{"-t=333", "-e=99", "-k=13"},
{"-t=11", "-e=256", "-k=5"},
{"-t=64", "-e=455", "-k=8"},
{"-t=777", "-e=802", "-k=99"},
{"-t=4097", "-e=906", "-k=51"},
{"-t=128", "-e=32", "-k=5", "-moe_buf_size=262144"},
{"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
{"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
{"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
{"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
{"-t=128", "-e=128", "-k=6", "-moe_buf_size=163840"},
{"-t=8192", "-e=32", "-k=5", "-moe_buf_size=163840"},
{"-t=8192", "-e=32", "-k=8", "-moe_buf_size=163840"},
{"-t=8192", "-e=256", "-k=5", "-moe_buf_size=163840"},
{"-t=8192", "-e=256", "-k=8", "-moe_buf_size=163840"},
{"-t=163840", "-e=256", "-k=8", "-moe_buf_size=163840"},
{"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
{"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
{"-t=99", "-local_t=93", "-e=121", "-moe_buf_size=10244"},
{"-t=536", "-local_t=345", "-e=802", "-k=99"},
{"-t=331", "-local_t=39", "-e=83", "-k=33"},
{"-t=765", "-local_t=654", "-e=783", "-k=8"},
{"-t=23", "-local_t=9", "-e=1", "-k=1"},
{"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
{"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
{"-t=133940", "-local_t=111921", "-e=256", "-k=17", "-moe_buf_size=133940"}};
#endif
}
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases();
return !run_test_cases<float, ck_tile::index_t>(test_cases);
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_moe_sorting_types.hpp"
#include "test_moe_sorting_util.hpp"
#include "gtest/gtest.h"
#define TEST_SUITE_NAME TestCkTileMoeSorting
TYPED_TEST_SUITE(TestCkTileMoeSorting, KernelTypesMoeSorting);
#include "test_moe_sorting_cases.inc"
#undef TEST_SUITE_NAME

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
using KernelTypesMoeSorting = ::testing::Types<std::tuple<float, ck_tile::index_t>>;

View File

@@ -0,0 +1,356 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
void print_vector(std::vector<int>& data)
{
for(const auto& x : data)
{
std::cout << x << ",";
}
std::cout << " ";
}
template <typename Tuple>
class TestCkTileMoeSorting : public ::testing::Test
{
protected:
using WeightType = std::tuple_element_t<0, Tuple>;
using IndexType = std::tuple_element_t<1, Tuple>;
void RunSingle(int tokens,
int local_tokens,
int num_experts,
int topk,
int unit_size,
std::vector<int>& local_eid,
#if MOE_SORTING_FMOE_2D_BUF
int moe_buf_interm_dim,
int moe_buf_elem_bytes)
#else
int64_t moe_buf_size)
#endif
{
std::string index_prec = get_precision_string<IndexType>();
std::string weight_prec = get_precision_string<WeightType>();
bool clear_inside = true;
int dispatch_policy = 0;
int max_output_ids = ck_tile::integer_least_multiple(
topk * tokens + num_experts * unit_size - topk, unit_size);
int seed = 42; // Fixed seed for testing reproducibility
if(topk > num_experts)
{
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
topk,
num_experts);
EXPECT_TRUE(false);
}
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for
// such case
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
if(local_tokens > tokens)
{
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
EXPECT_TRUE(false);
}
bool local_expert_masking = !local_eid.empty();
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
// auto local_eid = args.get_int_vec("local_eid");
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
// for simplicity, below buffer allocate 2 dword
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_host(
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
moe_buf_elem_bytes});
auto moe_buf_bytes = moe_buf_interm_dim == 0
? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#else
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#endif
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weights_dev(
sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_dev(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
// used for simulating dynamic_tokens for EP case
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
if(is_local_token)
{
local_tokens_dev.ToDevice(&local_tokens);
}
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_bytes > 0)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size =
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0 && clear_inside == false)
moe_sorting_ws.SetZero(); // note, clear here!!!!
moe_sorting_trait trait{
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
topk,
#if MOE_SORTING_FMOE_2D_BUF
moe_buf_interm_dim,
moe_buf_elem_bytes
#else
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
};
ck_tile::stream_config sc{nullptr, false};
auto ret_val = moe_sorting(trait, karg, sc);
printf("[%s|%s|%s|%d]tokens:%d",
index_prec.c_str(),
weight_prec.c_str(),
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
dispatch_policy,
tokens);
if(is_local_token)
{
printf("(%d)", local_tokens);
}
printf(
", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
if(local_expert_masking)
{
printf("local_eid:");
print_vector(local_eid);
}
if(moe_buf_bytes > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
printf("moe_buf:%lu(%d,%d), ",
static_cast<uint64_t>(moe_buf_bytes),
moe_buf_interm_dim,
moe_buf_elem_bytes);
#else
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
}
if(ret_val < 0)
{
printf("not supported\n");
fflush(stdout);
EXPECT_TRUE(false);
}
sorted_ids_dev.FromDevice(sorted_ids_host.data());
sorted_weights_dev.FromDevice(sorted_weights_host.data());
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
if(moe_buf_bytes > 0)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
bool rtn = true;
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
local_expert_masking_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
unit_size,
is_local_token ? local_tokens
: tokens,
local_expert_masking);
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
{
size_t slen = ref_total_tokens_post_pad;
rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
sorted_ids_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect ids!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
sorted_weights_ref.slice({0}, {slen}),
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
auto t_ = is_local_token ? local_tokens : tokens;
bool _f = t_ == sorted_id_cnt_host.mData[1];
rtn &= _f;
if(!_f)
{
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
}
}
else
{
printf("(token size not equal!!)");
rtn = false;
}
if(moe_buf_bytes)
{
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
printf("valid:%s", rtn ? "y" : "n");
fflush(stdout);
if(!rtn)
printf(", (%d)", seed);
printf("\n");
fflush(stdout);
EXPECT_TRUE(rtn);
}
template <typename PrecisionType>
static std::string get_precision_string()
{
if constexpr(std::is_same_v<PrecisionType, float>)
{
return "fp32";
}
else if(std::is_same_v<PrecisionType, ck_tile::index_t>)
{
return "int32";
}
else
{
throw std::runtime_error("Invalid precision.");
}
}
};

View File

@@ -3,7 +3,7 @@ if(GPU_TARGETS MATCHES "gfx9")
function (add_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_test_executable(${TARGET_NAME} ${MAIN_SRC})
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
@@ -20,8 +20,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endfunction(add_smoothquant_test TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_test(test_ck_tile_smoothquant_fp16 smoothquant_fp16.cpp ${INSTANCE_SRCS})
add_smoothquant_test(test_ck_tile_smoothquant_bf16 smoothquant_bf16.cpp ${INSTANCE_SRCS})
add_smoothquant_test(test_ck_tile_smoothquant test_smoothquant.cpp ${INSTANCE_SRCS})
else()
message(DEBUG "Skipping ck_tile smoothquant tests for current target")

View File

@@ -22,9 +22,7 @@ using trait_ = smoothquant_traits_<DataType_,
kTwoPass_>;
template <typename data_type>
float smoothquant_dispatch(smoothquant_traits /*t*/,
smoothquant_args a,
const ck_tile::stream_config& s)
float smoothquant_dispatch(smoothquant_args a, const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
@@ -128,16 +126,14 @@ float smoothquant_dispatch(smoothquant_traits /*t*/,
// clang-format on
}
float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s)
template <>
float smoothquant<ck_tile::fp16_t>(smoothquant_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
return smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
return smoothquant_dispatch<ck_tile::fp16_t>(a, s);
}
template <>
float smoothquant<ck_tile::bf16_t>(smoothquant_args a, const ck_tile::stream_config& s)
{
return smoothquant_dispatch<ck_tile::bf16_t>(a, s);
}

View File

@@ -111,4 +111,5 @@ struct smoothquant_traits
std::string data_type;
};
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
template <typename DataType>
float smoothquant(smoothquant_args, const ck_tile::stream_config&);

View File

@@ -1,273 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[], int index = 0)
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv, index);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", y_stride:" << y_stride << std::flush;
smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
n,
x_stride,
y_stride};
float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
std::vector<std::vector<std::string>> create_test_cases(const std::string prec)
{
return {{"-prec=" + prec, "-m=99", "-n=13", "-x_stride=-1"},
{"-prec=" + prec, "-m=17", "-n=16", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=100", "-x_stride=-1"},
{"-prec=" + prec, "-m=4", "-n=128", "-x_stride=-1"},
{"-prec=" + prec, "-m=80", "-n=127", "-x_stride=-1"},
{"-prec=" + prec, "-m=22", "-n=255", "-x_stride=256"},
{"-prec=" + prec, "-m=7", "-n=599", "-x_stride=-1"},
{"-prec=" + prec, "-m=19", "-n=512", "-x_stride=-1"},
{"-prec=" + prec, "-m=33", "-n=313", "-x_stride=1000"},
{"-prec=" + prec, "-m=11", "-n=510", "-x_stride=-1"},
{"-prec=" + prec, "-m=171", "-n=676", "-x_stride=818"},
{"-prec=" + prec, "-m=91", "-n=636", "-x_stride=-1"},
{"-prec=" + prec, "-m=12", "-n=768", "-x_stride=800"},
{"-prec=" + prec, "-m=100", "-n=766", "-x_stride=812"},
{"-prec=" + prec, "-m=31", "-n=1024", "-x_stride=-1"},
{"-prec=" + prec, "-m=64", "-n=1000", "-x_stride=1004"},
{"-prec=" + prec, "-m=8", "-n=1501", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=1826", "-x_stride=-1"},
{"-prec=" + prec, "-m=5", "-n=2040", "-x_stride=-1"},
{"-prec=" + prec, "-m=7", "-n=2734", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=3182", "-x_stride=-1"},
{"-prec=" + prec, "-m=9", "-n=4096", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=8192", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=10547", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=17134", "-x_stride=-1"}};
}
template <typename DataType>
bool run_test_case(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
return run<DataType>(arg_parser);
}
template <typename DataType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
bool valid = true;
constexpr int num_args = 4;
char* argv[num_args];
for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
{
assert(test_cases[test_idx].size() == num_args &&
"invalid number of arguments in test case");
for(std::size_t idx = 0; idx < num_args; ++idx)
{
argv[idx] = test_cases[test_idx][idx].data();
}
valid = valid && run_test_case<DataType>(num_args, argv);
if(!valid)
break;
}
return valid;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("bf16");
return !run_test_cases<ck_tile::bf16_t>(test_cases);
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp16");
return !run_test_cases<ck_tile::half_t>(test_cases);
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_smoothquant_types.hpp"
#include "test_smoothquant_util.hpp"
#include "gtest/gtest.h"
#define TEST_SUITE_NAME TestCkTileSmoothquant
TYPED_TEST_SUITE(TestCkTileSmoothquant, KernelTypesSmoothquant);
#include "test_smoothquant_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,206 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#ifndef TEST_SMOOTHQUANT_CASES_INC
#define TEST_SMOOTHQUANT_CASES_INC
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m99_n13)
{
ck_tile::index_t m = 99;
ck_tile::index_t n = 13;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m17_n16)
{
ck_tile::index_t m = 17;
ck_tile::index_t n = 16;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m1_n100)
{
ck_tile::index_t m = 1;
ck_tile::index_t n = 100;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m4_n128)
{
ck_tile::index_t m = 4;
ck_tile::index_t n = 128;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m80_n127)
{
ck_tile::index_t m = 80;
ck_tile::index_t n = 127;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m22_n255)
{
ck_tile::index_t m = 22;
ck_tile::index_t n = 255;
ck_tile::index_t x_stride = 256;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m7_n599)
{
ck_tile::index_t m = 7;
ck_tile::index_t n = 599;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m33_n313)
{
ck_tile::index_t m = 33;
ck_tile::index_t n = 313;
ck_tile::index_t x_stride = 1000;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m11_n510)
{
ck_tile::index_t m = 11;
ck_tile::index_t n = 510;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m171_n676)
{
ck_tile::index_t m = 171;
ck_tile::index_t n = 676;
ck_tile::index_t x_stride = 818;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m91_n636)
{
ck_tile::index_t m = 91;
ck_tile::index_t n = 636;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m12_n768)
{
ck_tile::index_t m = 12;
ck_tile::index_t n = 768;
ck_tile::index_t x_stride = 800;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m100_n766)
{
ck_tile::index_t m = 100;
ck_tile::index_t n = 766;
ck_tile::index_t x_stride = 812;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m31_n1024)
{
ck_tile::index_t m = 31;
ck_tile::index_t n = 1024;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m64_n1000)
{
ck_tile::index_t m = 64;
ck_tile::index_t n = 1000;
ck_tile::index_t x_stride = 1004;
this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m8_n1501)
{
ck_tile::index_t m = 8;
ck_tile::index_t n = 1501;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m3_n1826)
{
ck_tile::index_t m = 3;
ck_tile::index_t n = 1826;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m5_n2040)
{
ck_tile::index_t m = 5;
ck_tile::index_t n = 2040;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m7_n2734)
{
ck_tile::index_t m = 7;
ck_tile::index_t n = 2734;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m1_n3182)
{
ck_tile::index_t m = 1;
ck_tile::index_t n = 3182;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m9_n4096)
{
ck_tile::index_t m = 9;
ck_tile::index_t n = 4096;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m3_n8192)
{
ck_tile::index_t m = 3;
ck_tile::index_t n = 8192;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m1_n10547)
{
ck_tile::index_t m = 1;
ck_tile::index_t n = 10547;
this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothqauant_m3_n17134)
{
ck_tile::index_t m = 3;
ck_tile::index_t n = 17134;
this->Run(m, n);
}
#endif

View File

@@ -0,0 +1,9 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
using KernelTypesSmoothquant =
::testing::Types<std::tuple<ck_tile::fp16_t>, std::tuple<ck_tile::bf16_t>>;

View File

@@ -0,0 +1,181 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
template <typename Tuple>
class TestCkTileSmoothquant : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
void Run(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t x_stride = -1,
ck_tile::index_t y_stride = -1)
{
if(x_stride < 0)
x_stride = n;
if(y_stride < 0)
y_stride = n;
assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
std::cout << "m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", y_stride:" << y_stride << std::flush;
smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
n,
x_stride,
y_stride};
smoothquant<DataType>(args, ck_tile::stream_config{nullptr, false});
bool pass = true;
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
EXPECT_TRUE(pass);
}
};