Merge commit '10395fc895a73727cf0bda5a44a88d1b2595dcb2' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-14 19:11:31 +00:00
parent 47789f61ee
commit eff48156cb
40 changed files with 3252 additions and 1966 deletions

View File

@@ -2,7 +2,7 @@
if(GPU_TARGETS MATCHES "gfx9")
function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_test_executable(${TARGET_NAME} ${MAIN_SRC})
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
@@ -21,11 +21,7 @@ if(GPU_TARGETS MATCHES "gfx9")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_moe_smoothquant_test(test_ck_tile_moe_smoothquant_fp16_fp8 moe_smoothquant_fp16_fp8.cpp ${INSTANCE_SRCS})
add_moe_smoothquant_test(test_ck_tile_moe_smoothquant_fp16_int8 moe_smoothquant_fp16_int8.cpp ${INSTANCE_SRCS})
add_moe_smoothquant_test(test_ck_tile_moe_smoothquant_bf16_fp8 moe_smoothquant_bf16_fp8.cpp ${INSTANCE_SRCS})
add_moe_smoothquant_test(test_ck_tile_moe_smoothquant_bf16_int8 moe_smoothquant_bf16_int8.cpp ${INSTANCE_SRCS})
add_moe_smoothquant_test(test_ck_tile_moe_smoothquant test_moe_smoothquant.cpp ${INSTANCE_SRCS})
else()
message(DEBUG "Skipping ck_tile MOE smoothquant tests for current target")

View File

@@ -24,9 +24,7 @@ using trait_ = moe_smoothquant_traits_<InType,
kTwoPass_>;
template <typename in_type, typename out_type>
float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
moe_smoothquant_args a,
const ck_tile::stream_config& s)
float moe_smoothquant_dispatch(moe_smoothquant_args a, const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
@@ -130,26 +128,30 @@ float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
// clang-format on
}
float moe_smoothquant(moe_smoothquant_traits t,
moe_smoothquant_args a,
const ck_tile::stream_config& s)
template <>
float moe_smoothquant<ck_tile::fp16_t, ck_tile::int8_t>(moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
if(t.in_type.compare("fp16") == 0 && t.out_type == "int8")
{
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.in_type.compare("fp16") == 0 && t.out_type == "fp8")
{
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
}
else if(t.in_type.compare("bf16") == 0 && t.out_type == "int8")
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.in_type.compare("bf16") == 0 && t.out_type == "fp8")
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::int8_t>(a, s);
};
template <>
float moe_smoothquant<ck_tile::fp16_t, ck_tile::fp8_t>(moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::fp8_t>(a, s);
};
template <>
float moe_smoothquant<ck_tile::bf16_t, ck_tile::int8_t>(moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::int8_t>(a, s);
};
template <>
float moe_smoothquant<ck_tile::bf16_t, ck_tile::fp8_t>(moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::fp8_t>(a, s);
};

View File

@@ -95,10 +95,5 @@ template <typename Traits_>
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
// This is the public API, will be generated by script
struct moe_smoothquant_traits
{
std::string in_type; // input type
std::string out_type; // output type
};
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
template <typename InputType, typename OutputType>
float moe_smoothquant(moe_smoothquant_args, const ck_tile::stream_config&);

View File

@@ -1,317 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include <cstring>
#include <set>
#include <hip/hip_runtime.h>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    // default: tight tolerance for floating-point outputs
    const double rel_tol = 1e-5;
    const double abs_tol = 1e-5;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    // bf16 currently shares the default tolerance
    const double rel_tol = 1e-5;
    const double abs_tol = 1e-5;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
    // due to rounding, int8 quantization might have 1 abs error
    const double rel_tol = 1;
    const double abs_tol = 1;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// Fill host_tensor with (tokens * topk) expert ids in [0, num_expert); within
// each consecutive group of `topk` entries the ids are unique. Deterministic
// for a fixed seed (uses std::srand/std::rand).
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);
    std::set<IndexType> seen_in_group;
    const size_t total = static_cast<size_t>(topk) * tokens;
    for(size_t idx = 0; idx < total; idx++)
    {
        if(idx % topk == 0)
            seen_in_group.clear(); // new token -> fresh uniqueness scope
        IndexType candidate;
        do
        {
            candidate = std::rand() % num_expert;
        } while(seen_in_group.count(candidate) != 0);
        seen_in_group.insert(candidate);
        host_tensor[idx] = candidate;
    }
}
// Build the command-line parser for the moe-smoothquant example and parse argv.
// Returns (parse_succeeded, parser).
auto create_args(int argc, char* argv[], int index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("t", "3328", "tokens dimension");
    arg_parser.insert("h", "4096", "hidden_size dimension");
    arg_parser.insert("e", "32", "experts");
    arg_parser.insert("k", "5", "topk");
    arg_parser.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size");
    arg_parser.insert("v", "1", "cpu validation or not");
    arg_parser.insert("kname", "1", "print kernel name or not");
    arg_parser.insert("prec_i", "fp16", "input precision, fp16/bf16");
    arg_parser.insert("prec_o", "int8", "precision, int8/fp8");
    arg_parser.insert("warmup", "5", "cold iter");
    arg_parser.insert("repeat", "20", "hot iter");
    const bool parsed_ok = arg_parser.parse(argc, argv, index);
    return std::make_tuple(parsed_ok, arg_parser);
}
// Run one moe-smoothquant configuration end-to-end: build random inputs,
// launch the device kernel (benchmarked), and optionally validate the yscale
// vector and the row-wise quantized output against a host reference.
// Returns true when validation passes (or when validation is disabled).
template <typename InputType, typename OutputType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t tokens      = arg_parser.get_int("t");
    ck_tile::index_t hidden_size = arg_parser.get_int("h");
    ck_tile::index_t stride      = arg_parser.get_int("stride");
    if(stride < 0)
        stride = hidden_size; // -1 means "dense rows"
    ck_tile::index_t experts = arg_parser.get_int("e");
    ck_tile::index_t topk    = arg_parser.get_int("k");
    std::string prec_i       = arg_parser.get_str("prec_i");
    std::string prec_o       = arg_parser.get_str("prec_o");
    int kname                = arg_parser.get_int("kname");
    int do_validation        = arg_parser.get_int("v");
    int warmup               = arg_parser.get_int("warmup");
    int repeat               = arg_parser.get_int("repeat");
    assert(stride >= hidden_size);
    using TypeConfig          = MoeSmoothquantTypeConfig<InputType, OutputType>;
    using XDataType           = typename TypeConfig::XDataType;
    using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
    using YScaleDataType      = typename TypeConfig::YScaleDataType;
    using QYDataType          = typename TypeConfig::QYDataType;
    using ComputeDataType     = typename TypeConfig::ComputeDataType;
    // host tensors; x is [tokens, hidden_size], outputs are expanded to topk*tokens rows
    ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
    ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
    ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
    ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
    ck_tile::HostTensor<YScaleDataType> yscale_host_dev({topk * tokens}, {1});
    ck_tile::HostTensor<QYDataType> qy_host_ref({topk * tokens, hidden_size}, {stride, 1});
    ck_tile::HostTensor<QYDataType> qy_host_dev({topk * tokens, hidden_size}, {stride, 1});
    // per-token unique expert ids plus uniformly distributed inputs/scales
    topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
    x_buf.ToDevice(x_host.data());
    smscale_buf.ToDevice(smscale_host.data());
    topk_ids_buf.ToDevice(topk_ids_host.data());
    std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens
              << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts
              << ", topk:" << topk << std::flush;
    moe_smoothquant_traits traits{prec_i, prec_o};
    moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
                              smscale_buf.GetDeviceBuffer(),
                              topk_ids_buf.GetDeviceBuffer(),
                              yscale_buf.GetDeviceBuffer(),
                              qy_buf.GetDeviceBuffer(),
                              tokens,
                              hidden_size,
                              experts,
                              topk,
                              stride,
                              stride};
    // benchmarked launch (warmup + timed repeats)
    float ave_time = moe_smoothquant(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
    std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size +
                           sizeof(SmoothScaleDataType) * topk * hidden_size +
                           sizeof(YScaleDataType) * topk * tokens +
                           sizeof(QYDataType) * topk * tokens * hidden_size;
    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
    bool pass = true;
    if(do_validation)
    {
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<ComputeDataType> y_host({topk * tokens, hidden_size}, {stride, 1});
        // smooth outlier: y[row] = x * smscale of the routed expert (topk-major rows)
        {
            auto f = [&](auto i_token) {
                for(int i_topk = 0; i_topk < topk; i_topk++)
                {
                    auto i_expert = topk_ids_host(i_token, i_topk);
                    for(int i_h = 0; i_h < hidden_size; ++i_h)
                    {
                        auto v_smscale = ck_tile::type_convert<ComputeDataType>(
                            smscale_host(i_expert * hidden_size + i_h));
                        auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
                        // y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
                        y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
                    }
                }
            };
            ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency());
        }
        // yscale: per-row absmax divided by the max representable QY value
        {
            ck_tile::HostTensor<YDataType> y_rowwise_amax_host({topk * tokens});
            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});
            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);
            yscale_buf.FromDevice(yscale_host_dev.mData.data());
            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }
        // rowwise quantization
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);
            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();
            if(stride == hidden_size)
            {
                // BUGFIX: was `pass = ...`, which silently discarded the yscale
                // check result accumulated above; accumulate with &= instead.
                pass &= ck_tile::check_err(qy_host_dev,
                                           qy_host_ref,
                                           std::string("qy Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
            else
            {
                // strided rows: compare only the valid hidden_size prefix of each row
                for(int i_r = 0; i_r < topk * tokens; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
                                                            qy_host_dev.begin() + i_r * stride +
                                                                hidden_size);
                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
                                                            qy_host_ref.begin() + i_r * stride +
                                                                hidden_size);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }
    return pass;
}
// Produce the synthetic command lines for the shape sweep.
// Each entry is {prec_i, prec_o, tokens, hidden_size, stride}; stride == -1
// means "use hidden_size".
std::vector<std::vector<std::string>> generate_test_cases(const std::string prec_in,
                                                          const std::string prec_out)
{
    struct Shape
    {
        int t, h, stride;
    };
    static constexpr Shape shapes[] = {
        {99, 13, -1},    {17, 16, -1},   {1, 100, -1},    {4, 128, -1},   {80, 127, -1},
        {22, 255, 256},  {7, 599, -1},   {19, 512, -1},   {33, 313, 1000}, {11, 510, -1},
        {171, 676, 818}, {12, 768, 800}, {100, 766, 812}, {31, 1024, -1}, {64, 1000, 1004},
        {8, 1501, -1},   {3, 1826, -1},  {5, 2040, -1},   {7, 2734, -1},  {1, 3182, -1},
        {9, 4096, -1},   {3, 8192, -1},  {1, 10547, -1},  {3, 17134, -1}};
    std::vector<std::vector<std::string>> cases;
    for(const Shape& s : shapes)
    {
        cases.push_back({"-prec_i=" + prec_in,
                         "-prec_o=" + prec_out,
                         "-t=" + std::to_string(s.t),
                         "-h=" + std::to_string(s.h),
                         "-stride=" + std::to_string(s.stride)});
    }
    return cases;
}
// Parse one synthetic command line and execute that single configuration.
// Returns false when parsing fails or the configuration does not validate.
template <typename InputType, typename OutputType>
bool run_test_case(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    return parsed_ok && run<InputType, OutputType>(arg_parser);
}
// Execute every configuration in order, stopping at the first failure.
template <typename InputType, typename OutputType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    constexpr int num_args = 5;
    char* argv[num_args];
    for(auto& test_case : test_cases)
    {
        assert(test_case.size() == num_args && "invalid number of arguments");
        for(int arg_idx = 0; arg_idx < num_args; ++arg_idx)
        {
            argv[arg_idx] = test_case[arg_idx].data();
        }
        if(!run_test_case<InputType, OutputType>(num_args, argv))
            return false;
    }
    return true;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "moe_smoothquant.inc"
int main()
{
    // Exercise the bf16 -> fp8 shape sweep; exit 0 iff every case validates.
    auto test_cases = generate_test_cases("bf16", "fp8");
    const bool all_passed = run_test_cases<ck_tile::bf16_t, ck_tile::fp8_t>(test_cases);
    return all_passed ? 0 : 1;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "moe_smoothquant.inc"
int main()
{
    // Exercise the bf16 -> int8 shape sweep; exit 0 iff every case validates.
    auto test_cases = generate_test_cases("bf16", "int8");
    const bool all_passed = run_test_cases<ck_tile::bf16_t, ck_tile::int8_t>(test_cases);
    return all_passed ? 0 : 1;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "moe_smoothquant.inc"
int main()
{
    // Exercise the fp16 -> fp8 shape sweep; exit 0 iff every case validates.
    auto test_cases = generate_test_cases("fp16", "fp8");
    const bool all_passed = run_test_cases<ck_tile::half_t, ck_tile::fp8_t>(test_cases);
    return all_passed ? 0 : 1;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "moe_smoothquant.inc"
int main()
{
    // Exercise the fp16 -> int8 shape sweep; exit 0 iff every case validates.
    auto test_cases = generate_test_cases("fp16", "int8");
    const bool all_passed = run_test_cases<ck_tile::half_t, ck_tile::int8_t>(test_cases);
    return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_moe_smoothquant_types.hpp"
#include "test_moe_smoothquant_util.hpp"
#include "gtest/gtest.h"
// Instantiate the typed fixture (from test_moe_smoothquant_util.hpp) for every
// (input, output) type pair, then pull in the shared shape-sweep case list.
// TEST_SUITE_NAME is what the .inc file's TYPED_TEST macros expand against.
#define TEST_SUITE_NAME TestCkTileMoeSmoothquant
TYPED_TEST_SUITE(TestCkTileMoeSmoothquant, KernelTypesMoeSmoothquant);
#include "test_moe_smoothquant_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,206 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#ifndef TEST_MOE_SMOOTHQUANT_CASES_INC
#define TEST_MOE_SMOOTHQUANT_CASES_INC
// Shape sweep for the moe-smoothquant typed fixture. The includer must
// #define TEST_SUITE_NAME to the fixture class before including this file.
// Run(tokens, hidden_size[, stride]); an omitted stride defaults to hidden_size.
// clang-format off
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t99_h13)   { this->Run(99, 13); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t17_h16)   { this->Run(17, 16); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t1_h100)   { this->Run(1, 100); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t4_h128)   { this->Run(4, 128); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t80_h127)  { this->Run(80, 127); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t22_h255)  { this->Run(22, 255, 256); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t7_h599)   { this->Run(7, 599); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t19_h512)  { this->Run(19, 512); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t33_h313)  { this->Run(33, 313, 1000); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t11_h510)  { this->Run(11, 510); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t171_h676) { this->Run(171, 676, 818); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t12_h768)  { this->Run(12, 768, 800); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t100_h766) { this->Run(100, 766, 812); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t31_h1024) { this->Run(31, 1024); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t64_h1000) { this->Run(64, 1000, 1004); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t8_h1501)  { this->Run(8, 1501); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t3_h1826)  { this->Run(3, 1826); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t5_h2040)  { this->Run(5, 2040); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t7_h2734)  { this->Run(7, 2734); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t1_h3182)  { this->Run(1, 3182); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t9_h4096)  { this->Run(9, 4096); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t3_h8192)  { this->Run(3, 8192); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t1_h10547) { this->Run(1, 10547); }
TYPED_TEST(TEST_SUITE_NAME, MoeSmoothquant_t3_h17134) { this->Run(3, 17134); }
// clang-format on
#endif

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
// (InputType, OutputType) pairs the typed test suite is instantiated with;
// consumed by TYPED_TEST_SUITE in the test registration TU.
using KernelTypesMoeSmoothquant = ::testing::Types<std::tuple<ck_tile::bf16_t, ck_tile::fp8_t>,
                                                   std::tuple<ck_tile::bf16_t, ck_tile::int8_t>,
                                                   std::tuple<ck_tile::fp16_t, ck_tile::fp8_t>,
                                                   std::tuple<ck_tile::fp16_t, ck_tile::int8_t>>;

View File

@@ -0,0 +1,218 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include <cstring>
#include <set>
#include <hip/hip_runtime.h>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    // default: tight tolerance for floating-point outputs
    const double rel_tol = 1e-5;
    const double abs_tol = 1e-5;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    // bf16 currently shares the default tolerance
    const double rel_tol = 1e-5;
    const double abs_tol = 1e-5;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
    // due to rounding, int8 quantization might have 1 abs error
    const double rel_tol = 1;
    const double abs_tol = 1;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// Fill host_tensor with (tokens * topk) expert ids in [0, num_expert); within
// each consecutive group of `topk` entries the ids are unique. Deterministic
// for a fixed seed (uses std::srand/std::rand).
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);
    std::set<IndexType> seen_in_group;
    const size_t total = static_cast<size_t>(topk) * tokens;
    for(size_t idx = 0; idx < total; idx++)
    {
        if(idx % topk == 0)
            seen_in_group.clear(); // new token -> fresh uniqueness scope
        IndexType candidate;
        do
        {
            candidate = std::rand() % num_expert;
        } while(seen_in_group.count(candidate) != 0);
        seen_in_group.insert(candidate);
        host_tensor[idx] = candidate;
    }
}
// Typed gtest fixture for the moe-smoothquant kernel.
// Tuple = (InputType, OutputType) as registered in KernelTypesMoeSmoothquant.
template <typename Tuple>
class TestCkTileMoeSmoothquant : public ::testing::Test
{
    protected:
    using InputType  = std::tuple_element_t<0, Tuple>;
    using OutputType = std::tuple_element_t<1, Tuple>;
    // Generate random inputs for one shape, launch the kernel, and validate
    // the yscale vector and quantized rows against a host reference.
    // stride < 0 means "use hidden_size" (dense rows).
    void Run(ck_tile::index_t tokens,
             ck_tile::index_t hidden_size,
             ck_tile::index_t stride  = -1,
             ck_tile::index_t experts = 32,
             ck_tile::index_t topk    = 5)
    {
        if(stride < 0)
            stride = hidden_size;
        assert(stride >= hidden_size);
        using TypeConfig          = MoeSmoothquantTypeConfig<InputType, OutputType>;
        using XDataType           = typename TypeConfig::XDataType;
        using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
        using YScaleDataType      = typename TypeConfig::YScaleDataType;
        using QYDataType          = typename TypeConfig::QYDataType;
        using ComputeDataType     = typename TypeConfig::ComputeDataType;
        // host tensors; x is [tokens, hidden_size], outputs expanded to topk*tokens rows
        ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
        ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
        ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
        ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
        ck_tile::HostTensor<YScaleDataType> yscale_host_dev({topk * tokens}, {1});
        ck_tile::HostTensor<QYDataType> qy_host_ref({topk * tokens, hidden_size}, {stride, 1});
        ck_tile::HostTensor<QYDataType> qy_host_dev({topk * tokens, hidden_size}, {stride, 1});
        // per-token unique expert ids plus uniformly distributed inputs/scales
        topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
        ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
        ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
        ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
        ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
        x_buf.ToDevice(x_host.data());
        smscale_buf.ToDevice(smscale_host.data());
        topk_ids_buf.ToDevice(topk_ids_host.data());
        std::cout << "tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
                  << ", experts:" << experts << ", topk:" << topk << std::flush;
        moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
                                  smscale_buf.GetDeviceBuffer(),
                                  topk_ids_buf.GetDeviceBuffer(),
                                  yscale_buf.GetDeviceBuffer(),
                                  qy_buf.GetDeviceBuffer(),
                                  tokens,
                                  hidden_size,
                                  experts,
                                  topk,
                                  stride,
                                  stride};
        // correctness-only launch (no benchmarking)
        moe_smoothquant<InputType, OutputType>(args, ck_tile::stream_config{nullptr, false});
        bool pass       = true;
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<ComputeDataType> y_host({topk * tokens, hidden_size}, {stride, 1});
        // smooth outlier: y[row] = x * smscale of the routed expert (topk-major rows)
        {
            auto f = [&](auto i_token) {
                for(int i_topk = 0; i_topk < topk; i_topk++)
                {
                    auto i_expert = topk_ids_host(i_token, i_topk);
                    for(int i_h = 0; i_h < hidden_size; ++i_h)
                    {
                        auto v_smscale = ck_tile::type_convert<ComputeDataType>(
                            smscale_host(i_expert * hidden_size + i_h));
                        auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
                        // y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
                        y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
                    }
                }
            };
            ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency());
        }
        // yscale: per-row absmax divided by the max representable QY value
        {
            ck_tile::HostTensor<YDataType> y_rowwise_amax_host({topk * tokens});
            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});
            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);
            yscale_buf.FromDevice(yscale_host_dev.mData.data());
            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }
        // rowwise quantization
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);
            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();
            if(stride == hidden_size)
            {
                // BUGFIX: was `pass = ...`, which silently discarded the yscale
                // check result accumulated above; accumulate with &= instead.
                pass &= ck_tile::check_err(qy_host_dev,
                                           qy_host_ref,
                                           std::string("qy Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
            else
            {
                // strided rows: compare only the valid hidden_size prefix of each row
                for(int i_r = 0; i_r < topk * tokens; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
                                                            qy_host_dev.begin() + i_r * stride +
                                                                hidden_size);
                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
                                                            qy_host_ref.begin() + i_r * stride +
                                                                hidden_size);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
        EXPECT_TRUE(pass);
    }
};

View File

@@ -1,14 +1,19 @@
# Currently ck_tile is only built on gfx90a, gfx942 and gfx950
if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a")
add_test_executable(test_ck_tile_moe_sorting_fp32 moe_sorting_fp32.cpp moe_sorting_api.cpp)
target_include_directories(test_ck_tile_moe_sorting_fp32 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
function(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(${EXECUTABLE} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(test_ck_tile_moe_sorting_fp32 PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DMOE_SORTING_FMOE_2D_BUF=${USE_2D_BUF})
target_compile_options(${EXECUTABLE} PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
endfunction(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
add_moe_sorting_test(test_ck_tile_moe_sorting_2d_buf 1)
add_moe_sorting_test(test_ck_tile_moe_sorting 0)
else()
message(DEBUG "Skipping ck_tile_moe_sorting tests for current target")

View File

@@ -1,544 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
// Build the command-line parser for the moe-sorting example and parse argv.
// Returns (parse_succeeded, parser).
// NOTE(review): `index` is passed through to ArgParser::parse — presumably the
// argv offset to start parsing at; confirm against ArgParser's definition.
auto create_args(int argc, char* argv[], int index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
        .insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
        .insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
        .insert("t",
                "128",
                "number of input tokens.\n"
                "If \"local_t\" presents, this value indicates global concurrency of all ranks.")
        .insert(
            "local_t",
            "-1",
            "Number of local input tokens for curent rank.\n"
            "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
            "This feature is to simulate EP case where where each rank has different tokens.\n"
            "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
        .insert("e", "8", "number of num_experts")
        .insert("k", "4", "topk")
        .insert("unit", "32", "unit_size")
// the fmoe scratch buffer is described either as (interm_dim, elem_bytes) or as
// a flat byte size, depending on the build-time 2D-buffer switch
#if MOE_SORTING_FMOE_2D_BUF
        .insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
        .insert(
            "moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
#else
        .insert("moe_buf_size", "0", "moe_buf_size")
#endif
        .insert("ci",
                "1",
                "clear workspace inside API or not(if \"0\", require manually clear outside)")
        .insert(
            "dispatch",
            "0",
            "dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
        .insert("local_eid",
                "-1",
                "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
                "please make sure eid is in ascending order!")
        .insert("seed",
                "-1",
                "seed to be used. When set to -1, a random seed will be generated each time "
                "invoking this example")
        .insert("kname", "0", "prints the kernel name when set to 1")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel");
    bool result = arg_parser.parse(argc, argv, index);
    return std::make_tuple(result, arg_parser);
}
// Fill host_tensor with (tokens * topk) expert ids in [0, num_expert); within
// each consecutive group of `topk` entries the ids are unique. Deterministic
// for a fixed seed (uses std::srand/std::rand).
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);
    std::set<IndexType> seen_in_group;
    const size_t total = static_cast<size_t>(topk) * tokens;
    for(size_t idx = 0; idx < total; idx++)
    {
        if(idx % topk == 0)
            seen_in_group.clear(); // new token -> fresh uniqueness scope
        IndexType candidate;
        do
        {
            candidate = std::rand() % num_expert;
        } while(seen_in_group.count(candidate) != 0);
        seen_in_group.insert(candidate);
        host_tensor[idx] = candidate;
    }
}
// Run a single moe-sorting configuration described by the parsed command line.
// Generates random (token, topk) expert ids and weights, launches the device
// kernel, and - when "-v" is set - validates the sorted ids/weights/expert ids
// (and the zero-cleared fused-moe buffer) against the host reference.
// Returns true when the kernel is supported and all enabled checks pass.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
    // ---- command-line options ----
    int validate = args.get_int("v");
    std::string index_prec = args.get_str("pr_i");
    std::string weight_prec = args.get_str("pr_w");
    int tokens = args.get_int("t");
    int local_tokens = args.get_int("local_t");
    int num_experts = args.get_int("e");
    int topk = args.get_int("k");
    int seed = args.get_int("seed");
    int unit_size = args.get_int("unit");
#if MOE_SORTING_FMOE_2D_BUF
    int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
    int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
#else
    int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
#endif
    int kname = args.get_int("kname");
    int warmup = args.get_int("warmup");
    int repeat = args.get_int("repeat");
    bool clear_inside = args.get_int("ci") != 0;
    int dispatch_policy = args.get_int("dispatch");
    // worst-case length of the padded/sorted output: every expert's slot list
    // rounded up to a multiple of unit_size
    int max_output_ids =
        ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
    if(seed < 0)
    {
        // "-seed=-1" requests a fresh random seed on every invocation
        seed = std::time(nullptr);
    }
    if(topk > num_experts)
    {
        printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
               topk,
               num_experts);
        return false;
    }
    // if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
    // case
    bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
    if(local_tokens > tokens)
    {
        printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
        return false;
    }
    // "-local_eid=-1" (the default) disables expert masking
    bool local_expert_masking = args.get_str("local_eid") != "-1";
    // 0/1 mask per expert; collapses to a 1-element placeholder when masking is off
    auto local_expert_masking_host = [&]() {
        if(local_expert_masking)
        {
            auto local_eid = args.get_int_vec("local_eid");
            ck_tile::HostTensor<IndexType> v_{{num_experts}};
            v_.SetZero();
            for(auto eid : local_eid)
            {
                if(eid >= num_experts)
                {
                    throw std::runtime_error(
                        "local_eid larger than number of expert, please check");
                }
                v_.mData[eid] = 1;
            }
            return v_;
        }
        else
            return ck_tile::HostTensor<IndexType>{{1}};
    }();
    // ---- host buffers ----
    // tokens already considered batch size
    ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
    ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
    ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
    // for simplicity, below buffer allocate 2 dword
    ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
    ck_tile::HostTensor<int8_t> moe_buf_host(
        {static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
         moe_buf_elem_bytes});
    auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
                                                 : moe_buf_host.get_element_space_size_in_bytes();
#else
    ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
    auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
                                           : moe_buf_host.get_element_space_size_in_bytes();
#endif
    // random weights; moe_buf is filled with garbage so the kernel's clear can be verified
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
    ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
    topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
    // ---- device buffers & upload ----
    ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_expert_ids_dev(
        sorted_expert_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem local_expert_masking_dev(
        local_expert_masking_host.get_element_space_size_in_bytes());
    // used for simulating dynamic_tokens for EP case
    ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
    if(is_local_token)
    {
        local_tokens_dev.ToDevice(&local_tokens);
    }
    topk_ids_dev.ToDevice(topk_ids_host.data());
    weights_dev.ToDevice(weights_host.data());
    if(moe_buf_bytes > 0)
    {
        moe_buf_dev.ToDevice(moe_buf_host.data());
    }
    if(local_expert_masking)
        local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
    // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
    ck_tile::index_t workspace_size =
        moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
    ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
    if(workspace_size != 0 && clear_inside == false)
        moe_sorting_ws.SetZero(); // note, clear here!!!!
    // ---- kernel launch ----
    moe_sorting_trait trait{
        index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
    moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
                          weights_dev.GetDeviceBuffer(),
                          local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
                                               : nullptr,
                          is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
                          sorted_ids_dev.GetDeviceBuffer(),
                          sorted_weights_dev.GetDeviceBuffer(),
                          sorted_expert_ids_dev.GetDeviceBuffer(),
                          sorted_id_cnt_dev.GetDeviceBuffer(),
                          moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
                          workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
                          tokens,
                          unit_size,
                          num_experts,
                          topk,
#if MOE_SORTING_FMOE_2D_BUF
                          moe_buf_interm_dim,
                          moe_buf_elem_bytes
#else
                          static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
    };
    ck_tile::stream_config sc{nullptr,
                              true,
                              /* log_level = */ (kname ? 1 : 0),
                              warmup,
                              repeat};
    // ms < 0 means the configuration is not supported by any kernel instance
    auto ms = moe_sorting(trait, karg, sc);
    printf("[%s|%s|%s|%d]tokens:%d",
           index_prec.c_str(),
           weight_prec.c_str(),
           workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
           dispatch_policy,
           tokens);
    if(is_local_token)
    {
        printf("(%d)", local_tokens);
    }
    printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
    if(local_expert_masking)
    {
        printf("local_eid:%s, ", args.get_str("local_eid").c_str());
    }
    if(moe_buf_bytes > 0)
    {
#if MOE_SORTING_FMOE_2D_BUF
        printf("moe_buf:%lu(%d,%d), ",
               static_cast<uint64_t>(moe_buf_bytes),
               moe_buf_interm_dim,
               moe_buf_elem_bytes);
#else
        printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
    }
    if(ms < 0)
        printf("not supported\n");
    else
        printf("ms:%f, ", ms);
    fflush(stdout);
    if(ms < 0)
    {
        return false;
    }
    // ---- download results ----
    sorted_ids_dev.FromDevice(sorted_ids_host.data());
    sorted_weights_dev.FromDevice(sorted_weights_host.data());
    sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
    sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
    if(moe_buf_bytes > 0)
    {
        moe_buf_dev.FromDevice(moe_buf_host.data());
    }
    // ---- validation against host reference ----
    bool rtn = true;
    if(validate)
    {
        ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
        ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
        ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
        int32_t ref_total_tokens_post_pad = 0;
        ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
                                                              weights_host,
                                                              local_expert_masking_host,
                                                              sorted_ids_ref,
                                                              sorted_weights_ref,
                                                              sorted_expert_ids_ref,
                                                              ref_total_tokens_post_pad,
                                                              num_experts,
                                                              unit_size,
                                                              is_local_token ? local_tokens
                                                                             : tokens,
                                                              local_expert_masking);
        printf("total_tokens_post_pad:%d(%d), ",
               ref_total_tokens_post_pad,
               sorted_id_cnt_host.mData[0]);
        // only compare element-wise when the padded lengths agree
        if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
        {
            size_t slen = ref_total_tokens_post_pad;
            rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
                                      sorted_ids_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect ids!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
                                      sorted_weights_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect w!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
                                      sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
                                      std::string("OUT Error: Incorrect eid!"),
                                      1e-6,
                                      1e-6);
            // second dword of the counter buffer carries the effective token count
            // if(is_local_token)
            {
                auto t_ = is_local_token ? local_tokens : tokens;
                bool _f = t_ == sorted_id_cnt_host.mData[1];
                rtn &= _f;
                if(!_f)
                {
                    printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
                }
            }
        }
        else
        {
            printf("(token size not equal!!)");
            rtn = false;
        }
        if(moe_buf_bytes)
        {
            // the kernel must have zero-cleared the fused-moe buffer; a
            // default-constructed (zeroed) host tensor is the reference
#if MOE_SORTING_FMOE_2D_BUF
            ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
            ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
            rtn &= ck_tile::check_err(
                moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
        }
        // rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
    }
    printf("valid:%s", rtn ? "y" : "n");
    fflush(stdout);
    if(!rtn)
        printf(", (%d)", seed); // echo the seed so a failing run can be replayed
    printf("\n");
    fflush(stdout);
    return rtn;
}
// Parse the command line and, on success, run test_moe_sorting with the
// parsed arguments. Returns false when parsing fails or the test fails.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_case(int argc, char* argv[])
{
    auto [parse_ok, parsed_args] = create_args(argc, argv);
    return parse_ok && test_moe_sorting<WeightType, IndexType>(parsed_args);
}
// Run every test case in `test_cases`, stopping at the first failure.
// Each test case is an argv-style list of option strings; returns true iff
// all cases pass. A std::runtime_error thrown by a case aborts the run.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    bool valid = true;
    for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
    {
        // Build a mutable argv view over the stored strings. A dynamically
        // sized vector (instead of the previous fixed char*[7] guarded only
        // by assert, which is compiled out in release builds) cannot overflow
        // when a test case carries more arguments than anticipated.
        const int num_args = static_cast<int>(test_cases[test_idx].size());
        std::vector<char*> argv(num_args);
        for(int arg_idx = 0; arg_idx < num_args; ++arg_idx)
        {
            argv[arg_idx] = test_cases[test_idx][arg_idx].data();
        }
        try
        {
            valid = valid && run_test_case<WeightType, IndexType>(num_args, argv.data());
            if(!valid)
                break;
        }
        catch(const std::runtime_error& e)
        {
            std::cerr << "Runtime error: " << e.what() << '\n';
            return false;
        }
    }
    return valid;
}
// Predefined argv-style test configurations covering plain sorting, expert
// masking (-local_eid), dynamic local tokens (-local_t) and fused-moe buffer
// clearing. Note: each case carries at most one value per flag; the 2D_BUF
// branch previously listed "-local_t" twice in two cases (merge artifact),
// which made the effective value parser-order dependent — the duplicates are
// removed to match the corresponding #else cases.
std::vector<std::vector<std::string>> create_test_cases()
{
#if MOE_SORTING_FMOE_2D_BUF
    return {{"-t=80", "-e=17", "-moe_buf_interm_dim=16", "-moe_buf_elem_bytes=4"},
            {"-t=111", "-e=117", "-moe_buf_interm_dim=4", "-moe_buf_elem_bytes=4"},
            {"-t=1000", "-e=55", "-moe_buf_interm_dim=1024", "-moe_buf_elem_bytes=1"},
            {"-t=99", "-e=120", "-moe_buf_interm_dim=10244", "-moe_buf_elem_bytes=2"},
            {"-t=175", "-e=64", "-k=8"},
            {"-t=65", "-e=8", "-k=2"},
            {"-t=1", "-e=25"},
            {"-t=31", "-e=19", "-k=15"},
            {"-t=81", "-e=37", "-k=7"},
            {"-t=23", "-e=1", "-k=1"},
            {"-t=127", "-e=99", "-k=19"},
            {"-t=71", "-e=11", "-k=11"},
            {"-t=1", "-e=1", "-k=1"},
            {"-t=99", "-e=2", "-k=1"},
            {"-t=333", "-e=99", "-k=13"},
            {"-t=11", "-e=256", "-k=5"},
            {"-t=64", "-e=455", "-k=8"},
            {"-t=777", "-e=802", "-k=99"},
            {"-t=4097", "-e=906", "-k=51"},
            {"-t=128", "-e=32", "-k=5", "-local_t=6", "-moe_buf_interm_dim=262144"},
            {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
            {"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
            {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
            {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
            {"-t=128", "-e=128", "-k=6", "-moe_buf_interm_dim=163840", "-moe_buf_elem_bytes=1"},
            {"-t=8192", "-e=32", "-k=5", "-local_t=11", "-moe_buf_interm_dim=163840"},
            {"-t=8192",
             "-e=32",
             "-k=8",
             "-local_t=12",
             "-moe_buf_interm_dim=163840",
             "-moe_buf_elem_bytes=1"},
            {"-t=8192", "-e=256", "-k=5", "-local_t=13", "-moe_buf_interm_dim=163840"},
            {"-t=8192", "-e=256", "-k=8", "-local_t=8", "-moe_buf_interm_dim=163840"},
            {"-t=163840",
             "-e=256",
             "-k=8",
             "-local_t=4",
             "-moe_buf_interm_dim=163840",
             "-moe_buf_elem_bytes=4"},
            {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
            {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
            {"-t=99", "-local_t=93", "-e=121", "-moe_buf_interm_dim=10244"},
            {"-t=536", "-local_t=345", "-e=802", "-k=99"},
            {"-t=331", "-local_t=39", "-e=83", "-k=33"},
            {"-t=765", "-local_t=654", "-e=783", "-k=8"},
            {"-t=23", "-local_t=9", "-e=1", "-k=1"},
            {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
            {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
            {"-t=133940",
             "-local_t=111921",
             "-e=256",
             "-k=17",
             "-moe_buf_interm_dim=133940",
             "-moe_buf_elem_bytes=1"}};
#else
    return {{"-t=80", "-e=17", "-moe_buf_size=16"},
            {"-t=111", "-e=117", "-moe_buf_size=4"},
            {"-t=1000", "-e=55", "-moe_buf_size=1024"},
            {"-t=99", "-e=120", "-moe_buf_size=10244"},
            {"-t=175", "-e=64", "-k=8"},
            {"-t=65", "-e=8", "-k=2"},
            {"-t=1", "-e=25"},
            {"-t=31", "-e=19", "-k=15"},
            {"-t=81", "-e=37", "-k=7"},
            {"-t=23", "-e=1", "-k=1"},
            {"-t=127", "-e=99", "-k=19"},
            {"-t=71", "-e=11", "-k=11"},
            {"-t=1", "-e=1", "-k=1"},
            {"-t=99", "-e=2", "-k=1"},
            {"-t=333", "-e=99", "-k=13"},
            {"-t=11", "-e=256", "-k=5"},
            {"-t=64", "-e=455", "-k=8"},
            {"-t=777", "-e=802", "-k=99"},
            {"-t=4097", "-e=906", "-k=51"},
            {"-t=128", "-e=32", "-k=5", "-moe_buf_size=262144"},
            {"-t=13", "-e=64", "-k=3", "-local_eid=4,5,6,7,8,9,10,11"},
            {"-t=99", "-e=33", "-k=9", "-local_eid=6,10,11,15,19"},
            {"-t=80", "-e=99", "-k=10", "-local_eid=0,8,12,33"},
            {"-t=11", "-e=256", "-k=5", "-local_eid=99,110,129"},
            {"-t=128", "-e=128", "-k=6", "-moe_buf_size=163840"},
            {"-t=8192", "-e=32", "-k=5", "-moe_buf_size=163840"},
            {"-t=8192", "-e=32", "-k=8", "-moe_buf_size=163840"},
            {"-t=8192", "-e=256", "-k=5", "-moe_buf_size=163840"},
            {"-t=8192", "-e=256", "-k=8", "-moe_buf_size=163840"},
            {"-t=163840", "-e=256", "-k=8", "-moe_buf_size=163840"},
            {"-t=12", "-local_t=3", "-e=256", "-k=5", "-local_eid=9,10,199,145"},
            {"-t=67", "-local_t=9", "-e=555", "-k=5", "-local_eid=19,23,24,25,26,99"},
            {"-t=99", "-local_t=93", "-e=121", "-moe_buf_size=10244"},
            {"-t=536", "-local_t=345", "-e=802", "-k=99"},
            {"-t=331", "-local_t=39", "-e=83", "-k=33"},
            {"-t=765", "-local_t=654", "-e=783", "-k=8"},
            {"-t=23", "-local_t=9", "-e=1", "-k=1"},
            {"-t=7", "-local_t=0", "-e=89", "-k=1", "-local_eid=0,8,12,33"},
            {"-t=61", "-local_t=0", "-e=333", "-k=99", "-local_eid=0,8,12,33"},
            {"-t=133940", "-local_t=111921", "-e=256", "-k=17", "-moe_buf_size=133940"}};
#endif
}
// Entry point: run all predefined moe-sorting test cases.
// Exit code is 0 when every case passes, 1 otherwise.
int main()
{
    auto cases = create_test_cases();
    const bool all_passed = run_test_cases<float, ck_tile::index_t>(cases);
    return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_moe_sorting_types.hpp"
#include "test_moe_sorting_util.hpp"
#include "gtest/gtest.h"
#define TEST_SUITE_NAME TestCkTileMoeSorting
TYPED_TEST_SUITE(TestCkTileMoeSorting, KernelTypesMoeSorting);
#include "test_moe_sorting_cases.inc"
#undef TEST_SUITE_NAME

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
using KernelTypesMoeSorting = ::testing::Types<std::tuple<float, ck_tile::index_t>>;

View File

@@ -0,0 +1,356 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
// Fill host_tensor (tokens x topk, row-major) with random expert ids drawn
// from [0, num_expert); ids are unique within each token's topk group.
// Uses the C rand() stream seeded with `seed`, so output is reproducible.
// Precondition: topk <= num_expert, otherwise the re-draw loop never ends.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    size_t total_size = topk * tokens;
    std::srand(seed);
    std::set<IndexType> unique_set;
    IndexType current_v;
    for(size_t i = 0; i < total_size; i++)
    {
        // start of a new token's topk group: reset the uniqueness filter
        if(i % topk == 0)
        {
            unique_set.clear();
        }
        current_v = std::rand() % num_expert;
        // re-draw until the id is unique within this group
        while(unique_set.find(current_v) != unique_set.end())
        {
            current_v = std::rand() % num_expert;
        }
        unique_set.insert(current_v);
        host_tensor[i] = current_v;
    }
}
// Print `data` to stdout as "a,b,c, " — a comma after every element,
// then a single trailing space (matches the example driver's log format).
void print_vector(std::vector<int>& data)
{
    for(std::size_t i = 0; i < data.size(); ++i)
    {
        std::cout << data[i] << ",";
    }
    std::cout << " ";
}
// Typed gtest fixture for the ck_tile moe-sorting kernel. The tuple parameter
// carries <WeightType, IndexType>. RunSingle mirrors the standalone example
// driver: it builds random inputs with a fixed seed, launches the kernel once,
// and validates the sorted output against the host reference.
template <typename Tuple>
class TestCkTileMoeSorting : public ::testing::Test
{
    protected:
    using WeightType = std::tuple_element_t<0, Tuple>;
    using IndexType  = std::tuple_element_t<1, Tuple>;

    // Run one configuration. A non-empty `local_eid` enables expert masking;
    // local_tokens in [0, tokens) simulates the EP dynamic-token path.
    void RunSingle(int tokens,
                   int local_tokens,
                   int num_experts,
                   int topk,
                   int unit_size,
                   std::vector<int>& local_eid,
#if MOE_SORTING_FMOE_2D_BUF
                   int moe_buf_interm_dim,
                   int moe_buf_elem_bytes)
#else
                   int64_t moe_buf_size)
#endif
    {
        std::string index_prec  = get_precision_string<IndexType>();
        std::string weight_prec = get_precision_string<WeightType>();
        bool clear_inside       = true;
        int dispatch_policy     = 0;
        // worst-case padded output length (each expert padded to unit_size)
        int max_output_ids = ck_tile::integer_least_multiple(
            topk * tokens + num_experts * unit_size - topk, unit_size);
        int seed = 42; // Fixed seed for testing reproducibility
        if(topk > num_experts)
        {
            printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
                   topk,
                   num_experts);
            EXPECT_TRUE(false);
            // must bail out: topid_unique_gen cannot draw topk unique ids
            // from fewer than topk experts and would loop forever
            return;
        }
        // if local_tokens == tokens, not local_token, but better avoid this since no meaning for
        // such case
        bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
        if(local_tokens > tokens)
        {
            printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
            EXPECT_TRUE(false);
            return; // invalid configuration — do not launch
        }
        bool local_expert_masking = !local_eid.empty();
        // 0/1 mask per expert; 1-element placeholder when masking is off
        auto local_expert_masking_host = [&]() {
            if(local_expert_masking)
            {
                ck_tile::HostTensor<IndexType> v_{{num_experts}};
                v_.SetZero();
                for(auto eid : local_eid)
                {
                    if(eid >= num_experts)
                    {
                        throw std::runtime_error(
                            "local_eid larger than number of expert, please check");
                    }
                    v_.mData[eid] = 1;
                }
                return v_;
            }
            else
                return ck_tile::HostTensor<IndexType>{{1}};
        }();
        // ---- host buffers ----
        // tokens already considered batch size
        ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
        ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
        ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
        ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
        ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
        // for simplicity, below buffer allocate 2 dword
        ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
        ck_tile::HostTensor<int8_t> moe_buf_host(
            {static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
             moe_buf_elem_bytes});
        auto moe_buf_bytes = moe_buf_interm_dim == 0
                                 ? static_cast<std::size_t>(0)
                                 : moe_buf_host.get_element_space_size_in_bytes();
#else
        ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
        auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
                                               : moe_buf_host.get_element_space_size_in_bytes();
#endif
        // garbage-fill the moe buffer so the kernel's zero-clear is observable
        ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
        ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
        ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
        topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
        // ---- device buffers & upload ----
        ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_weights_dev(
            sorted_weights_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_expert_ids_dev(
            sorted_expert_ids_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem local_expert_masking_dev(
            local_expert_masking_host.get_element_space_size_in_bytes());
        // used for simulating dynamic_tokens for EP case
        ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
        if(is_local_token)
        {
            local_tokens_dev.ToDevice(&local_tokens);
        }
        topk_ids_dev.ToDevice(topk_ids_host.data());
        weights_dev.ToDevice(weights_host.data());
        if(moe_buf_bytes > 0)
        {
            moe_buf_dev.ToDevice(moe_buf_host.data());
        }
        if(local_expert_masking)
            local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
        // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
        ck_tile::index_t workspace_size =
            moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
        ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
        if(workspace_size != 0 && clear_inside == false)
            moe_sorting_ws.SetZero(); // note, clear here!!!!
        // ---- kernel launch ----
        moe_sorting_trait trait{
            index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
        moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
                              weights_dev.GetDeviceBuffer(),
                              local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
                                                   : nullptr,
                              is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
                              sorted_ids_dev.GetDeviceBuffer(),
                              sorted_weights_dev.GetDeviceBuffer(),
                              sorted_expert_ids_dev.GetDeviceBuffer(),
                              sorted_id_cnt_dev.GetDeviceBuffer(),
                              moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
                              workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
                              tokens,
                              unit_size,
                              num_experts,
                              topk,
#if MOE_SORTING_FMOE_2D_BUF
                              moe_buf_interm_dim,
                              moe_buf_elem_bytes
#else
                              static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
        };
        ck_tile::stream_config sc{nullptr, false};
        // negative return value means the configuration is unsupported
        auto ret_val = moe_sorting(trait, karg, sc);
        printf("[%s|%s|%s|%d]tokens:%d",
               index_prec.c_str(),
               weight_prec.c_str(),
               workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
               dispatch_policy,
               tokens);
        if(is_local_token)
        {
            printf("(%d)", local_tokens);
        }
        printf(
            ", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
        if(local_expert_masking)
        {
            printf("local_eid:");
            print_vector(local_eid);
        }
        if(moe_buf_bytes > 0)
        {
#if MOE_SORTING_FMOE_2D_BUF
            printf("moe_buf:%lu(%d,%d), ",
                   static_cast<uint64_t>(moe_buf_bytes),
                   moe_buf_interm_dim,
                   moe_buf_elem_bytes);
#else
            printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
        }
        if(ret_val < 0)
        {
            printf("not supported\n");
            fflush(stdout);
            EXPECT_TRUE(false);
            // bail out: without a successful launch the device buffers hold
            // garbage and validating them would only add noise to the failure
            return;
        }
        // ---- download & validate against host reference ----
        sorted_ids_dev.FromDevice(sorted_ids_host.data());
        sorted_weights_dev.FromDevice(sorted_weights_host.data());
        sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
        sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
        if(moe_buf_bytes > 0)
        {
            moe_buf_dev.FromDevice(moe_buf_host.data());
        }
        bool rtn = true;
        ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
        ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
        ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
        int32_t ref_total_tokens_post_pad = 0;
        ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
                                                              weights_host,
                                                              local_expert_masking_host,
                                                              sorted_ids_ref,
                                                              sorted_weights_ref,
                                                              sorted_expert_ids_ref,
                                                              ref_total_tokens_post_pad,
                                                              num_experts,
                                                              unit_size,
                                                              is_local_token ? local_tokens
                                                                             : tokens,
                                                              local_expert_masking);
        printf("total_tokens_post_pad:%d(%d), ",
               ref_total_tokens_post_pad,
               sorted_id_cnt_host.mData[0]);
        // only compare element-wise when the padded lengths agree
        if(ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0])
        {
            size_t slen = ref_total_tokens_post_pad;
            rtn &= ck_tile::check_err(sorted_ids_host.slice({0}, {slen}),
                                      sorted_ids_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect ids!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_weights_host.slice({0}, {slen}),
                                      sorted_weights_ref.slice({0}, {slen}),
                                      std::string("OUT Error: Incorrect w!"),
                                      1e-6,
                                      1e-6);
            rtn &= ck_tile::check_err(sorted_expert_ids_host.slice({0}, {slen / unit_size}),
                                      sorted_expert_ids_ref.slice({0}, {slen / unit_size}),
                                      std::string("OUT Error: Incorrect eid!"),
                                      1e-6,
                                      1e-6);
            // second dword of the counter buffer carries the effective token count
            auto t_ = is_local_token ? local_tokens : tokens;
            bool _f = t_ == sorted_id_cnt_host.mData[1];
            rtn &= _f;
            if(!_f)
            {
                printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
            }
        }
        else
        {
            printf("(token size not equal!!)");
            rtn = false;
        }
        if(moe_buf_bytes)
        {
            // the kernel must have zero-cleared the fused-moe buffer; a
            // default-constructed (zeroed) host tensor is the reference
#if MOE_SORTING_FMOE_2D_BUF
            ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
            ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
            rtn &= ck_tile::check_err(
                moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
        }
        printf("valid:%s", rtn ? "y" : "n");
        fflush(stdout);
        if(!rtn)
            printf(", (%d)", seed); // echo the seed so the failure can be replayed
        printf("\n");
        fflush(stdout);
        EXPECT_TRUE(rtn);
    }

    // Map a precision type to the string the dispatch API expects.
    // Throws for any type the test suite does not cover.
    template <typename PrecisionType>
    static std::string get_precision_string()
    {
        if constexpr(std::is_same_v<PrecisionType, float>)
        {
            return "fp32";
        }
        // `constexpr` keeps the whole chain compile-time (the original plain
        // `else if` silently degraded to a runtime branch)
        else if constexpr(std::is_same_v<PrecisionType, ck_tile::index_t>)
        {
            return "int32";
        }
        else
        {
            throw std::runtime_error("Invalid precision.");
        }
    }
};

View File

@@ -2,7 +2,7 @@
if(GPU_TARGETS MATCHES "gfx9")
function(add_permute_test TARGET_NAME MAIN_SRC)
add_test_executable(${TARGET_NAME} ${MAIN_SRC})
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL)
set(PERMUTE_USE_ALTERNATIVE_IMPL true)
@@ -10,23 +10,11 @@ if(GPU_TARGETS MATCHES "gfx9")
if(PERMUTE_USE_ALTERNATIVE_IMPL)
target_compile_options(${TARGET_NAME} PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL)
target_sources(${TARGET_NAME} PRIVATE alternative_impl/matrix_core_swizzle.cpp)
endif()
endfunction(add_permute_test TARGET_NAME MAIN_SRC)
set(CUSTOM_TARGET_NAME test_ck_tile_permute)
add_custom_target(${CUSTOM_TARGET_NAME})
add_permute_test(test_ck_tile_permute_fp16 permute_fp16.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp16)
add_permute_test(test_ck_tile_permute_fp8 permute_fp8.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp8)
add_permute_test(test_ck_tile_permute_fp32 permute_fp32.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_permute_fp32)
add_permute_test(test_ck_tile_permute test_permute.cpp)
else()
message(DEBUG "Skipping ck_tile_permute tests for current target")

View File

@@ -1,101 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}

View File

@@ -7,14 +7,125 @@
struct matrix_core_swizzle_traits
{
std::string data_type; // fp16 only
std::string inst; // 32x32x8, 16x16x16
std::string permute; //
std::string inst; // 32x32x8, 16x16x16
std::string permute;
};
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
template <typename DataType> // only supported with fp16 data type
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);
// fp16 host dispatch: selects the MFMA instruction ("32x32x8" or "16x16x16")
// and permute style from the traits, then launches the matching swizzle
// kernel. Returns the averaged kernel time, or -1 for an unsupported
// instruction/permute combination.
template <>
float matrix_core_swizzle<ck_tile::half_t>(matrix_core_swizzle_traits t,
                                           matrix_core_swizzle_args a,
                                           const ck_tile::stream_config& s)
{
    if(t.inst.compare("32x32x8") == 0)
    {
        constexpr int BLOCK_SIZE = 256;
        constexpr int NPerBlock = 256;
        constexpr int KPerBlock = 128;
        constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
        if(t.permute.compare("0,1,4,2,5,3,6") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
        else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
        else if(t.permute.compare("0,1,3,4,2,5") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::b_nr_kr_kw_nw_kv;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
    }
    else if(t.inst.compare("16x16x16") == 0)
    {
        constexpr int BLOCK_SIZE = 256;
        constexpr int NPerBlock = 256;
        constexpr int KPerBlock = 128;
        constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
        if(t.permute.compare("0,1,4,2,5,3,6") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
        else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
        else if(t.permute.compare("0,1,3,4,2,5") == 0)
        {
            constexpr matrix_core_permute_style pstyle =
                matrix_core_permute_style::b_nr_kr_kw_nw_kv;
            using Kernel =
                matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
            auto k = Kernel(a);
            float ave_time = ck_tile::launch_kernel(s, k);
            return ave_time;
        }
    }
    // unsupported instruction or permute pattern
    return -1;
}
// fp8 is not implemented for the swizzle example; fail loudly rather than
// silently returning a timing value.
template <>
float matrix_core_swizzle<ck_tile::fp8_t>(matrix_core_swizzle_traits,
                                          matrix_core_swizzle_args,
                                          const ck_tile::stream_config&)
{
    throw std::runtime_error("Not supported for fp8");
}
// fp32 is not implemented for the swizzle example; fail loudly rather than
// silently returning a timing value.
template <>
float matrix_core_swizzle<float>(matrix_core_swizzle_traits,
                                 matrix_core_swizzle_args,
                                 const ck_tile::stream_config&)
{
    throw std::runtime_error("Not supported for fp32");
}

View File

@@ -8,12 +8,4 @@
#include "ck_tile/ops/permute.hpp"
#include <string>
// Host-side launch configuration for the generic permute example.
struct permute_traits
{
    std::string data_type; // element type name, e.g. "fp16" / "fp32" / "fp8"
};
using permute_args = ck_tile::GenericPermuteHostArgs;
// host API
float permute(permute_traits, permute_args, const ck_tile::stream_config&);

View File

@@ -1,29 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases_fp16();
return !run_test_cases<ck_tile::half_t>(test_cases);
}

View File

@@ -1,29 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp32");
return !run_test_cases<float>(test_cases);
}

View File

@@ -1,29 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#include "permute_utils.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp8");
return !run_test_cases<ck_tile::fp8_t>(test_cases);
}

View File

@@ -1,490 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace detail {
// Maps a byte count to the signed integer type of the same width, so that
// device and host results can be compared bit-exactly via bit_cast
// (avoids any floating-point tolerance in validation).
template <int bytes>
struct to_integer_type;
template <>
struct to_integer_type<4>
{
    using type = int32_t;
};
template <>
struct to_integer_type<2>
{
    using type = int16_t;
};
template <>
struct to_integer_type<1>
{
    using type = int8_t;
};
} // namespace detail
// Convenience alias: to_integer_type<sizeof(T)> is the same-sized signed int.
template <int bytes>
using to_integer_type = typename detail::to_integer_type<bytes>::type;
// host API (shoule come from codegen)
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp8") == 0)
{
using DataType = ck_tile::fp8_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp16") == 0)
{
using DataType = ck_tile::half_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp32") == 0)
{
using DataType = float;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
// Streams a vector as "[a, b, c]" ("[]" when empty).
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    const char* sep = "";
    for(const T& elem : v)
    {
        os << sep << elem;
        sep = ", "; // separator precedes every element but the first
    }
    return os << "]";
}
// Builds the command-line parser for the permute example.
// `start_index` is the first argv entry to parse (lets callers skip argv[0]).
// Returns (parse_succeeded, parser).
// Fixes help-text typos: "weather do" -> "whether to do", "t to 1" -> "set to 1".
auto create_args(int argc, char* argv[], int start_index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
        .insert("shape", "2,3,4", "the shape of the input tensor")
        .insert("perm", "2,1,0", "permute perm")
        .insert("kname", "0", "set to 1 will print kernel name")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel");

    bool result = arg_parser.parse(argc, argv, start_index);
    return std::make_tuple(result, arg_parser);
}
// different threshold for different dtype
// Default validation tolerances (rtol, atol) for a given dtype.
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}
// bf16 carries fewer mantissa bits, so allow an order of magnitude more error.
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}
// fp8 tolerances depend on how inputs were initialized: "ui"/"ni"
// (integer-style) values are representable exactly, so no rounding distance
// is allowed; otherwise permit one rounding step and a coarser atol.
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
        unsigned max_rounding_point_distance = 0;
        double atol = 2e-3;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
    else
    {
        unsigned max_rounding_point_distance = 1;
        double atol = 0.0625;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
// Runs one permute problem end-to-end for DataType: reads the problem from
// arg_parser, executes either the generic permute kernel or (when
// PERMUTE_USE_ALTERNATIVE_IMPL is defined and the layout matches) the
// matrix-core swizzle fast path, then optionally bit-compares the device
// output against a host reference. Returns true when validation passes
// (or when validation is disabled).
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    std::string data_type = arg_parser.get_str("prec");
    int do_validation = arg_parser.get_int("v");
    auto shape = decode_vec(arg_parser.get_str("shape"));
    auto perm = decode_vec(arg_parser.get_str("perm"));
    int stream_warmup = arg_parser.get_int("warmup");
    int stream_repeat = arg_parser.get_int("repeat");
    bool kname = arg_parser.get_bool("kname");
    int seed = arg_parser.get_int("seed");
    assert(shape.size() == perm.size());
    ck_tile::index_t rank = perm.size();
    if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
    {
        printf("rank %d permute is not support yet\n", rank);
        return false;
    }
    // Integer-valued init so results can later be compared bit-for-bit.
    ck_tile::HostTensor<DataType> x(shape);
    ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
    // Output shape is the permutation of the input shape: y_shape[i] = shape[perm[i]].
    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        // std::cout << "@@@@" << tmp << std::endl;
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            // std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
            // static_cast<int>(rank)
            // << std::endl;
            tmp[i] = shape[perm[i]];
        }
        // std::cout << "@@@" << tmp << std::endl;
        return tmp;
    }();
    ck_tile::HostTensor<DataType> y(y_shape);
    ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
    x_buf.ToDevice(x.data());
    std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
              << std::endl;
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (kname ? 1 : 0),
                                         stream_warmup,
                                         stream_repeat};
    float ave_time = 0.f;
    // Fallback path: the generic (any-rank) permute kernel.
    auto run_permute = [&]() {
        permute_traits t;
        t.data_type = data_type;
        permute_args a;
        a.p_src = x_buf.GetDeviceBuffer();
        a.p_dst = y_buf.GetDeviceBuffer();
        a.rank = rank;
        std::copy(shape.begin(), shape.end(), a.shape);
        std::copy(perm.begin(), perm.end(), a.perm);
        return permute(t, a, stream_config);
    };
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
    // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
    // These three specific permutations may be served by the specialized
    // matrix-core swizzle kernels when the dimensions fit an MFMA layout.
    if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
        arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
    {
        if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
        {
            // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute = arg_parser.get_str("perm");
            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];
            auto nr = shape[1];
            auto nw = shape[2];
            auto kr = shape[3];
            auto kw = shape[4];
            auto kv = shape[5];
            a.n = nr * nw;
            a.k = kr * kw * kv;
            if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
            {
                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                // Dimensions do not fit either instruction layout.
                ave_time = run_permute();
            }
        }
        else
        {
            matrix_core_swizzle_traits t;
            t.data_type = data_type;
            t.permute = arg_parser.get_str("perm");
            matrix_core_swizzle_args a;
            a.p_src = x_buf.GetDeviceBuffer();
            a.p_dst = y_buf.GetDeviceBuffer();
            a.batch = shape[0];
            a.n = shape[1] * shape[2] * shape[3];
            a.k = shape[4] * shape[5] * shape[6];
            if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
               shape[4] % 8 == 0 && shape[1] % 2 == 0)
            {
                // 32x32x8 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
                // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
                t.inst = "32x32x8";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
                    shape[4] % 4 == 0 && shape[1] % 4 == 0)
            {
                // 16x16x16 inst
                // perm=0,1,4,2,5,3,6
                // y_shape=*,4x,4x,4,4,16,8
                // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
                t.inst = "16x16x16";
                std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
                ave_time = matrix_core_swizzle(t, a, stream_config);
            }
            else
            {
                // Dimensions do not fit either instruction layout.
                ave_time = run_permute();
            }
        }
    }
    else
#endif
    {
        ave_time = run_permute();
    }
    std::cout << ", time:" << ave_time << "ms" << std::flush;
    bool pass = true;
    if(do_validation)
    {
        // Compare against host reference bit-for-bit (inputs are integer-valued,
        // so a pure data-movement kernel must reproduce them exactly).
        reference_permute(x, y, perm);
        ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
        y_buf.FromDevice(y_dev.data());
        pass = std::equal(
            y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
                using itype = to_integer_type<sizeof(DataType)>;
                itype i_d = ck_tile::bit_cast<itype>(d);
                itype i_h = ck_tile::bit_cast<itype>(h);
                return i_d == i_h;
            });
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
    }
    std::cout << std::endl;
    return pass;
}
// Parses one argv set and executes a single permute test.
// Returns false when argument parsing fails or the test itself fails.
template <typename DataType>
bool run_test_case(int argc, char* argv[])
{
    auto parsed         = create_args(argc, argv);
    const bool parse_ok = std::get<0>(parsed);
    return parse_ok ? run<DataType>(std::get<1>(parsed)) : false;
}
// Executes each test case in order, stopping at the first failure.
// Every case must consist of exactly num_args argv-style strings.
template <typename DataType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    constexpr int num_args = 6;
    char* argv[num_args];
    for(auto& test_case : test_cases)
    {
        assert(test_case.size() == num_args && "invalid number of arguments in test case");
        for(int arg_idx = 0; arg_idx < num_args; ++arg_idx)
        {
            argv[arg_idx] = test_case[arg_idx].data();
        }
        if(!run_test_case<DataType>(num_args, argv))
        {
            return false; // fail fast, matching the original early-break
        }
    }
    return true;
}
// Builds the generic permute test matrix for the given precision.
// Each case is an argv-style vector; only shape/perm vary between cases.
std::vector<std::vector<std::string>> create_test_cases(const std::string prec)
{
    static const std::vector<std::pair<const char*, const char*>> shape_perm = {
        {"3,8", "1,0"},
        {"48,6,8", "2,1,0"},
        {"24,128,3", "0,2,1"},
        {"4,10,7,6", "0,2,3,1"},
        {"8,24,36,10", "3,1,2,0"},
        {"8,1,36,4", "2,1,0,3"},
        {"5,10,16,2,36,4", "4,5,2,1,0,3"},
        {"2,32,8,3,6,2,5,4", "5,2,4,7,1,6,3,0"},
    };

    std::vector<std::vector<std::string>> cases;
    cases.reserve(shape_perm.size());
    for(const auto& sp : shape_perm)
    {
        cases.push_back({"-prec=" + prec,
                         std::string("-shape=") + sp.first,
                         std::string("-perm=") + sp.second,
                         "-v=1",
                         "-warmup=0",
                         "-repeat=1"});
    }
    return cases;
}
// Builds the fp16 test matrix: the eight matrix-core-swizzle layouts first
// (candidates for the alternative implementation), then the eight generic
// permutes. Each case is an argv-style vector; only shape/perm vary.
std::vector<std::vector<std::string>> create_test_cases_fp16()
{
    static const std::vector<std::pair<const char*, const char*>> shape_perm = {
        // matrix-core swizzle layouts
        {"3,6,4,32,16,2,8", "0,1,4,2,5,3,6"},
        {"5,10,4,32,8,2,8", "0,1,4,2,5,3,6"},
        {"3,8,4,16,16,4,8", "0,1,4,2,5,3,6"},
        {"3,6,4,32,16,2,8", "0,1,2,4,5,3,6"},
        {"5,10,4,32,8,2,8", "0,1,2,4,5,3,6"},
        {"3,8,4,16,16,4,8", "0,1,2,4,5,3,6"},
        {"2,8,16,8,4,8", "0,1,3,4,2,5"},
        {"1,24,32,16,2,8", "0,1,3,4,2,5"},
        // generic permutes
        {"3,8", "1,0"},
        {"48,6,8", "2,1,0"},
        {"24,128,3", "0,2,1"},
        {"4,10,7,6", "0,2,3,1"},
        {"8,24,36,10", "3,1,2,0"},
        {"8,1,36,4", "2,1,0,3"},
        {"5,10,16,2,36,4", "4,5,2,1,0,3"},
        {"2,32,8,3,6,2,5,4", "5,2,4,7,1,6,3,0"},
    };

    std::vector<std::vector<std::string>> cases;
    cases.reserve(shape_perm.size());
    for(const auto& sp : shape_perm)
    {
        cases.push_back({"-prec=fp16",
                         std::string("-shape=") + sp.first,
                         std::string("-perm=") + sp.second,
                         "-v=1",
                         "-warmup=0",
                         "-repeat=1"});
    }
    return cases;
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_permute_types.hpp"
#include "test_permute_util.hpp"
#include "gtest/gtest.h"
#define TEST_SUITE_NAME TestCkTilePermute
TYPED_TEST_SUITE(TestCkTilePermute, KernelTypesPermute);
#include "test_permute_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,279 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#ifndef TEST_PERMUTE_CASES_INC
#define TEST_PERMUTE_CASES_INC
TYPED_TEST(TEST_SUITE_NAME, Permute1)
{
std::vector<ck_tile::index_t> shape{3, 8};
std::string perm{"1,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute2)
{
std::vector<ck_tile::index_t> shape{48, 6, 8};
std::string perm{"2,1,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute3)
{
std::vector<ck_tile::index_t> shape{24, 128, 3};
std::string perm{"0,2,1"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute4)
{
std::vector<ck_tile::index_t> shape{4, 10, 7, 6};
std::string perm{"0,2,3,1"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute5)
{
std::vector<ck_tile::index_t> shape{8, 24, 36, 10};
std::string perm{"3,1,2,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute6)
{
std::vector<ck_tile::index_t> shape{8, 1, 36, 4};
std::string perm{"2,1,0,3"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute7)
{
std::vector<ck_tile::index_t> shape{5, 10, 16, 2, 36, 4};
std::string perm{"4,5,2,1,0,3"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute8)
{
std::vector<ck_tile::index_t> shape{2, 32, 8, 3, 6, 2, 5, 4};
std::string perm{"5,2,4,7,1,6,3,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute9)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{3, 6, 4, 32, 16, 2, 8};
std::string perm{"0,1,4,2,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute10)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{5, 10, 4, 32, 8, 2, 8};
std::string perm{"0,1,4,2,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute11)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{3, 8, 4, 16, 16, 4, 8};
std::string perm{"0,1,4,2,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute12)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{3, 6, 4, 32, 16, 2, 8};
std::string perm{"0,1,2,4,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute13)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{5, 10, 4, 32, 8, 2, 8};
std::string perm{"0,1,2,4,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute14)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{3, 8, 4, 16, 16, 4, 8};
std::string perm{"0,1,2,4,5,3,6"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute15)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{2, 8, 16, 8, 4, 8};
std::string perm{"0,1,3,4,2,5"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute16)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{1, 24, 32, 16, 2, 8};
std::string perm{"0,1,3,4,2,5"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute17)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{3, 8};
std::string perm{"1,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute18)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{48, 6, 8};
std::string perm{"2,1,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute19)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{24, 128, 3};
std::string perm{"0,2,1"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute20)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{4, 10, 7, 6};
std::string perm{"0,2,3,1"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute21)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{8, 24, 36, 10};
std::string perm{"3,1,2,0"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute22)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{8, 1, 36, 4};
std::string perm{"2,1,0,3"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute23)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{5, 10, 16, 2, 36, 4};
std::string perm{"4,5,2,1,0,3"};
this->Run(shape, perm);
}
TYPED_TEST(TEST_SUITE_NAME, Permute24)
{
if constexpr(!std::is_same_v<TypeParam, F16Types>)
{
GTEST_SKIP() << "Skipping this test: Only run with fp16";
}
std::vector<ck_tile::index_t> shape{2, 32, 8, 3, 6, 2, 5, 4};
std::string perm{"5,2,4,7,1,6,3,0"};
this->Run(shape, perm);
}
#endif

View File

@@ -0,0 +1,10 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
// Element types under test, each wrapped in a std::tuple so gtest's typed
// tests can carry them; F16Types is named separately so individual cases can
// detect (and skip for) non-fp16 instantiations.
using F16Types = std::tuple<ck_tile::fp16_t>;
using KernelTypesPermute =
    ::testing::Types<F16Types, std::tuple<float>, std::tuple<ck_tile::fp8_t>>;

View File

@@ -0,0 +1,328 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
namespace detail {
template <int bytes>
struct to_integer_type;
template <>
struct to_integer_type<4>
{
using type = int32_t;
};
template <>
struct to_integer_type<2>
{
using type = int16_t;
};
template <>
struct to_integer_type<1>
{
using type = int8_t;
};
} // namespace detail
template <int bytes>
using to_integer_type = typename detail::to_integer_type<bytes>::type;
// host API (should come from codegen)
// Launches the generic permute kernel for DataType and returns the averaged
// kernel time as reported by launch_kernel.
template <typename DataType>
float permute(permute_args a, const ck_tile::stream_config& s)
{
    using Problem = ck_tile::GenericPermuteProblem<DataType>;
    using Kernel  = ck_tile::GenericPermute<Problem>;

    const dim3 grids      = Kernel::GridSize(a);
    constexpr dim3 blocks = Kernel::BlockSize();
    auto kargs            = Kernel::MakeKargs(a);

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
}
// Streams a vector as "[a, b, c]" ("[]" when empty).
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    const char* sep = "";
    for(const T& elem : v)
    {
        os << sep << elem;
        sep = ", "; // separator precedes every element but the first
    }
    return os << "]";
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
template <typename Tuple>
class TestCkTilePermute : public ::testing::Test
{
protected:
using DataType = std::tuple_element_t<0, Tuple>;
void Run(std::vector<ck_tile::index_t>& shape, std::string& perm)
{
std::string data_type = get_precision_string();
std::vector<ck_tile::index_t> perm_vec = decode_vec(perm);
int seed = 11939;
assert(shape.size() == perm_vec.size());
ck_tile::index_t rank = perm_vec.size();
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
{
printf("rank %d permute is not support yet\n", rank);
EXPECT_TRUE(false);
}
ck_tile::HostTensor<DataType> x(shape);
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = shape[perm_vec[i]];
}
return tmp;
}();
ck_tile::HostTensor<DataType> y(y_shape);
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
x_buf.ToDevice(x.data());
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape
<< ", permute:" << perm_vec << std::endl;
ck_tile::stream_config stream_config{nullptr, false, 0, 0, 1};
auto run_permute = [&]() {
permute_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.rank = rank;
std::copy(shape.begin(), shape.end(), a.shape);
std::copy(perm_vec.begin(), perm_vec.end(), a.perm);
return permute<DataType>(a, stream_config);
};
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((perm == std::string("0,1,4,2,5,3,6") || perm == std::string("0,1,2,4,5,3,6") ||
perm == std::string("0,1,3,4,2,5")))
{
if(perm == std::string("0,1,3,4,2,5"))
{
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.permute = perm;
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
auto nr = shape[1];
auto nw = shape[2];
auto kr = shape[3];
auto kw = shape[4];
auto kv = shape[5];
a.n = nr * nw;
a.k = kr * kw * kv;
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
matrix_core_swizzle<DataType>(t, a, stream_config);
}
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
matrix_core_swizzle<DataType>(t, a, stream_config);
}
else
{
run_permute();
}
}
else
{
matrix_core_swizzle_traits t;
t.permute = perm;
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
a.n = shape[1] * shape[2] * shape[3];
a.k = shape[4] * shape[5] * shape[6];
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
shape[4] % 8 == 0 && shape[1] % 2 == 0)
{
// 32x32x8 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
matrix_core_swizzle<DataType>(t, a, stream_config);
}
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
shape[4] % 4 == 0 && shape[1] % 4 == 0)
{
// 16x16x16 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,4x,4x,4,4,16,8
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
matrix_core_swizzle<DataType>(t, a, stream_config);
}
else
{
run_permute();
}
}
}
else
#endif
{
run_permute();
}
bool pass = true;
// Do Validation
reference_permute(x, y, perm_vec);
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
y_buf.FromDevice(y_dev.data());
pass = std::equal(
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
using itype = to_integer_type<sizeof(DataType)>;
itype i_d = ck_tile::bit_cast<itype>(d);
itype i_h = ck_tile::bit_cast<itype>(h);
return i_d == i_h;
});
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
std::cout << std::endl;
EXPECT_TRUE(pass);
}
static std::string get_precision_string()
{
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
{
return "fp16";
}
else if(std::is_same_v<DataType, ck_tile::fp8_t>)
{
return "fp8";
}
else if(std::is_same_v<DataType, float>)
{
return "fp32";
}
else
{
throw std::runtime_error("invalid precision");
}
}
};

View File

@@ -3,7 +3,7 @@ if(GPU_TARGETS MATCHES "gfx9")
function (add_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_test_executable(${TARGET_NAME} ${MAIN_SRC})
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
@@ -20,8 +20,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endfunction(add_smoothquant_test TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_test(test_ck_tile_smoothquant_fp16 smoothquant_fp16.cpp ${INSTANCE_SRCS})
add_smoothquant_test(test_ck_tile_smoothquant_bf16 smoothquant_bf16.cpp ${INSTANCE_SRCS})
add_smoothquant_test(test_ck_tile_smoothquant test_smoothquant.cpp ${INSTANCE_SRCS})
else()
message(DEBUG "Skipping ck_tile smoothquant tests for current target")

View File

@@ -22,9 +22,7 @@ using trait_ = smoothquant_traits_<DataType_,
kTwoPass_>;
template <typename data_type>
float smoothquant_dispatch(smoothquant_traits /*t*/,
smoothquant_args a,
const ck_tile::stream_config& s)
float smoothquant_dispatch(smoothquant_args a, const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
@@ -128,16 +126,14 @@ float smoothquant_dispatch(smoothquant_traits /*t*/,
// clang-format on
}
float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s)
template <>
float smoothquant<ck_tile::fp16_t>(smoothquant_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
return smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
return smoothquant_dispatch<ck_tile::fp16_t>(a, s);
}
template <>
float smoothquant<ck_tile::bf16_t>(smoothquant_args a, const ck_tile::stream_config& s)
{
return smoothquant_dispatch<ck_tile::bf16_t>(a, s);
}

View File

@@ -111,4 +111,5 @@ struct smoothquant_traits
std::string data_type;
};
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
template <typename DataType>
float smoothquant(smoothquant_args, const ck_tile::stream_config&);

View File

@@ -1,273 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
// Builds the command-line parser for the smoothquant example; `index` is the
// first argv entry to parse. Returns (parse_succeeded, parser).
auto create_args(int argc, char* argv[], int index = 0)
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
        .insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec", "fp16", "precision")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");
    bool result = arg_parser.parse(argc, argv, index);
    return std::make_tuple(result, arg_parser);
}
// Execute one smoothquant problem of size (m, n) on the device and, when
// requested, validate the device output against a host reference.
//
// arg_parser supplies: m/n (problem size), x_stride/y_stride (row strides,
// -1 means dense i.e. equal to n), prec (data type string), v (run host
// validation), kname (print kernel name), warmup/repeat (timing iterations).
//
// Returns true when validation passes (or when validation is skipped).
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t m        = arg_parser.get_int("m");
    ck_tile::index_t n        = arg_parser.get_int("n");
    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
    if(x_stride < 0)
        x_stride = n;
    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
    if(y_stride < 0)
        y_stride = n;
    std::string data_type = arg_parser.get_str("prec");
    int kname             = arg_parser.get_int("kname");
    int do_validation     = arg_parser.get_int("v");
    int warmup            = arg_parser.get_int("warmup");
    int repeat            = arg_parser.get_int("repeat");
    assert(x_stride >= n);
    using TypeConfig          = SmoothquantTypeConfig<DataType>;
    using XDataType           = typename TypeConfig::XDataType;
    using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
    using YScaleDataType      = typename TypeConfig::YScaleDataType;
    using QYDataType          = typename TypeConfig::QYDataType;
    using ComputeDataType     = typename TypeConfig::ComputeDataType;
    // Host tensors: *_ref holds the CPU reference, *_dev receives the device
    // result for comparison.
    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
    ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
    ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
    ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
    // Smooth scales are kept strictly positive so the per-row quantization
    // scale cannot degenerate to zero.
    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
    x_buf.ToDevice(x_host.data());
    smscale_buf.ToDevice(smscale_host.data());
    std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
              << ", y_stride:" << y_stride << std::flush;
    smoothquant_traits traits{data_type};
    smoothquant_args args{x_buf.GetDeviceBuffer(),
                          smscale_buf.GetDeviceBuffer(),
                          yscale_buf.GetDeviceBuffer(),
                          qy_buf.GetDeviceBuffer(),
                          m,
                          n,
                          x_stride,
                          y_stride};
    float ave_time = smoothquant(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
    // Total bytes moved by the kernel: read x and smscale, write yscale and qy.
    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
                           sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
    bool pass = true;
    if(do_validation)
    {
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
        // smooth outlier: y(m_, n_) = x(m_, n_) * smscale(n_), parallelized
        // over columns
        {
            auto f = [&](auto n_) {
                auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
                for(int m_ = 0; m_ < m; ++m_)
                {
                    auto v_x       = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
                    y_host(m_, n_) = v_x * v_smscale;
                }
            };
            ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
                std::thread::hardware_concurrency());
        }
        // yscale: per-row absmax of y divided by the largest representable
        // QYDataType value
        {
            ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});
            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);
            yscale_buf.FromDevice(yscale_host_dev.mData.data());
            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }
        // rowwise quantization: quantize each row of y with its reference
        // yscale and compare against the device output
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);
            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();
            if(y_stride == n)
            {
                // Dense output: compare the whole tensor at once. Accumulate
                // with &= so an earlier yscale failure is not discarded
                // (previously this plain assignment overwrote that result).
                pass &= ck_tile::check_err(qy_host_dev,
                                           qy_host_ref,
                                           std::string("qy Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
            else
            {
                // Strided output: compare only the n valid elements of each
                // row, skipping the padding between rows.
                for(int i_r = 0; i_r < m; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
                                                            qy_host_dev.begin() + i_r * y_stride +
                                                                n);
                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
                                                            qy_host_ref.begin() + i_r * y_stride +
                                                                n);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }
    return pass;
}
// Build the argv-style argument list for every smoothquant test case.
// Each case is {precision flag, m, n, x_stride}; an x_stride of -1 means the
// input rows are dense (stride == n).
std::vector<std::vector<std::string>> create_test_cases(const std::string prec)
{
    const std::string prec_flag = "-prec=" + prec;
    auto tc = [&prec_flag](const char* m_arg, const char* n_arg, const char* stride_arg) {
        return std::vector<std::string>{prec_flag, m_arg, n_arg, stride_arg};
    };
    return {tc("-m=99", "-n=13", "-x_stride=-1"),
            tc("-m=17", "-n=16", "-x_stride=-1"),
            tc("-m=1", "-n=100", "-x_stride=-1"),
            tc("-m=4", "-n=128", "-x_stride=-1"),
            tc("-m=80", "-n=127", "-x_stride=-1"),
            tc("-m=22", "-n=255", "-x_stride=256"),
            tc("-m=7", "-n=599", "-x_stride=-1"),
            tc("-m=19", "-n=512", "-x_stride=-1"),
            tc("-m=33", "-n=313", "-x_stride=1000"),
            tc("-m=11", "-n=510", "-x_stride=-1"),
            tc("-m=171", "-n=676", "-x_stride=818"),
            tc("-m=91", "-n=636", "-x_stride=-1"),
            tc("-m=12", "-n=768", "-x_stride=800"),
            tc("-m=100", "-n=766", "-x_stride=812"),
            tc("-m=31", "-n=1024", "-x_stride=-1"),
            tc("-m=64", "-n=1000", "-x_stride=1004"),
            tc("-m=8", "-n=1501", "-x_stride=-1"),
            tc("-m=3", "-n=1826", "-x_stride=-1"),
            tc("-m=5", "-n=2040", "-x_stride=-1"),
            tc("-m=7", "-n=2734", "-x_stride=-1"),
            tc("-m=1", "-n=3182", "-x_stride=-1"),
            tc("-m=9", "-n=4096", "-x_stride=-1"),
            tc("-m=3", "-n=8192", "-x_stride=-1"),
            tc("-m=1", "-n=10547", "-x_stride=-1"),
            tc("-m=3", "-n=17134", "-x_stride=-1")};
}
// Parse the arguments of a single test case and, when parsing succeeds,
// execute it via run<DataType>().
template <typename DataType>
bool run_test_case(int argc, char* argv[])
{
    const auto parsed = create_args(argc, argv);
    if(!std::get<0>(parsed))
        return false;
    return run<DataType>(std::get<1>(parsed));
}
// Execute every test case in order, stopping at the first failure.
// Each case must provide exactly num_args argv-style strings.
template <typename DataType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
    constexpr int num_args = 4;
    char* argv[num_args];
    for(auto& test_case : test_cases)
    {
        assert(test_case.size() == num_args && "invalid number of arguments in test case");
        // ArgParser wants mutable char*; std::string::data() is non-const
        // since C++17.
        for(int idx = 0; idx < num_args; ++idx)
        {
            argv[idx] = test_case[idx].data();
        }
        if(!run_test_case<DataType>(num_args, argv))
        {
            return false;
        }
    }
    return true;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
// Entry point: run the full smoothquant test-case list in bf16 precision.
// Exit code 0 on success, non-zero on the first failure.
int main()
{
    auto test_cases = create_test_cases("bf16");
    return run_test_cases<ck_tile::bf16_t>(test_cases) ? 0 : 1;
}

View File

@@ -1,11 +0,0 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
// Entry point: run the full smoothquant test-case list in fp16 precision.
// Exit code 0 on success, non-zero on the first failure.
int main()
{
    auto test_cases = create_test_cases("fp16");
    return run_test_cases<ck_tile::half_t>(test_cases) ? 0 : 1;
}

View File

@@ -0,0 +1,14 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_smoothquant_types.hpp"
#include "test_smoothquant_util.hpp"
#include "gtest/gtest.h"
// Fixture name under which the shared cases in test_smoothquant_cases.inc are
// expanded for this suite.
#define TEST_SUITE_NAME TestCkTileSmoothquant
// Instantiate the typed fixture once per element type listed in
// KernelTypesSmoothquant (see test_smoothquant_types.hpp).
TYPED_TEST_SUITE(TestCkTileSmoothquant, KernelTypesSmoothquant);
// Pull in the TYPED_TEST bodies; they reference TEST_SUITE_NAME.
#include "test_smoothquant_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,206 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#ifndef TEST_SMOOTHQUANT_CASES_INC
#define TEST_SMOOTHQUANT_CASES_INC
// Shared smoothquant test cases. Each TYPED_TEST exercises one (m, n) problem
// size -- optionally with a padded input row stride -- through the fixture's
// Run() helper. The including translation unit must #define TEST_SUITE_NAME.
// NOTE(review): test names were misspelled "Smoothqauant"; fixed to
// "Smoothquant", and the (m=19, n=512) case from the former command-line
// driver was restored.
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m99_n13)
{
    ck_tile::index_t m = 99;
    ck_tile::index_t n = 13;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m17_n16)
{
    ck_tile::index_t m = 17;
    ck_tile::index_t n = 16;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m1_n100)
{
    ck_tile::index_t m = 1;
    ck_tile::index_t n = 100;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m4_n128)
{
    ck_tile::index_t m = 4;
    ck_tile::index_t n = 128;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m80_n127)
{
    ck_tile::index_t m = 80;
    ck_tile::index_t n = 127;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m22_n255)
{
    ck_tile::index_t m        = 22;
    ck_tile::index_t n        = 255;
    ck_tile::index_t x_stride = 256;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m7_n599)
{
    ck_tile::index_t m = 7;
    ck_tile::index_t n = 599;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m19_n512)
{
    ck_tile::index_t m = 19;
    ck_tile::index_t n = 512;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m33_n313)
{
    ck_tile::index_t m        = 33;
    ck_tile::index_t n        = 313;
    ck_tile::index_t x_stride = 1000;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m11_n510)
{
    ck_tile::index_t m = 11;
    ck_tile::index_t n = 510;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m171_n676)
{
    ck_tile::index_t m        = 171;
    ck_tile::index_t n        = 676;
    ck_tile::index_t x_stride = 818;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m91_n636)
{
    ck_tile::index_t m = 91;
    ck_tile::index_t n = 636;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m12_n768)
{
    ck_tile::index_t m        = 12;
    ck_tile::index_t n        = 768;
    ck_tile::index_t x_stride = 800;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m100_n766)
{
    ck_tile::index_t m        = 100;
    ck_tile::index_t n        = 766;
    ck_tile::index_t x_stride = 812;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m31_n1024)
{
    ck_tile::index_t m = 31;
    ck_tile::index_t n = 1024;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m64_n1000)
{
    ck_tile::index_t m        = 64;
    ck_tile::index_t n        = 1000;
    ck_tile::index_t x_stride = 1004;
    this->Run(m, n, x_stride);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m8_n1501)
{
    ck_tile::index_t m = 8;
    ck_tile::index_t n = 1501;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m3_n1826)
{
    ck_tile::index_t m = 3;
    ck_tile::index_t n = 1826;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m5_n2040)
{
    ck_tile::index_t m = 5;
    ck_tile::index_t n = 2040;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m7_n2734)
{
    ck_tile::index_t m = 7;
    ck_tile::index_t n = 2734;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m1_n3182)
{
    ck_tile::index_t m = 1;
    ck_tile::index_t n = 3182;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m9_n4096)
{
    ck_tile::index_t m = 9;
    ck_tile::index_t n = 4096;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m3_n8192)
{
    ck_tile::index_t m = 3;
    ck_tile::index_t n = 8192;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m1_n10547)
{
    ck_tile::index_t m = 1;
    ck_tile::index_t n = 10547;
    this->Run(m, n);
}
TYPED_TEST(TEST_SUITE_NAME, Smoothquant_m3_n17134)
{
    ck_tile::index_t m = 3;
    ck_tile::index_t n = 17134;
    this->Run(m, n);
}
#endif

View File

@@ -0,0 +1,9 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "ck_tile/host.hpp"
#include "gtest/gtest.h"
// Element-type list the typed smoothquant suite is instantiated with; the
// first (and only) tuple element is consumed by the fixture as DataType
// (fp16 and bf16 variants).
using KernelTypesSmoothquant =
    ::testing::Types<std::tuple<ck_tile::fp16_t>, std::tuple<ck_tile::bf16_t>>;

View File

@@ -0,0 +1,181 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
// Generic case: tight tolerance for floating-point comparisons.
template <typename DataType>
auto get_elimit()
{
    return ck_tile::make_tuple(1e-5, 1e-5);
}
// bf16 shares the generic tight floating-point tolerance.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    constexpr double rel_tol = 1e-5;
    constexpr double abs_tol = 1e-5;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// int8: quantization rounding may shift a value by one, so allow an absolute
// (and relative) error of 1.
template <>
auto get_elimit<ck_tile::int8_t>()
{
    constexpr double rel_tol = 1.0;
    constexpr double abs_tol = 1.0;
    return ck_tile::make_tuple(rel_tol, abs_tol);
}
// Typed gtest fixture for the ck_tile smoothquant kernel. Each test calls
// Run() with one (m, n) problem size; Run() launches the device kernel and
// validates the per-row scale (yscale) and the quantized output (qy) against
// a host reference implementation.
template <typename Tuple>
class TestCkTileSmoothquant : public ::testing::Test
{
    protected:
    // Input element type of the kernel (first element of the type tuple).
    using DataType = std::tuple_element_t<0, Tuple>;

    // Execute and validate one smoothquant problem.
    // x_stride / y_stride: input / output row strides; -1 (the default) means
    // dense rows, i.e. stride == n.
    void Run(ck_tile::index_t m,
             ck_tile::index_t n,
             ck_tile::index_t x_stride = -1,
             ck_tile::index_t y_stride = -1)
    {
        if(x_stride < 0)
            x_stride = n;
        if(y_stride < 0)
            y_stride = n;
        assert(x_stride >= n);
        using TypeConfig          = SmoothquantTypeConfig<DataType>;
        using XDataType           = typename TypeConfig::XDataType;
        using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
        using YScaleDataType      = typename TypeConfig::YScaleDataType;
        using QYDataType          = typename TypeConfig::QYDataType;
        using ComputeDataType     = typename TypeConfig::ComputeDataType;
        // Host tensors: *_ref holds the CPU reference, *_dev receives the
        // device result for comparison.
        ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
        ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
        ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
        ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
        ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
        ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
        // Smooth scales are kept strictly positive so the per-row
        // quantization scale cannot degenerate to zero.
        ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
        ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
        ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
        ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
        ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
        x_buf.ToDevice(x_host.data());
        smscale_buf.ToDevice(smscale_host.data());
        std::cout << "m:" << m << ", n:" << n << ", x_stride:" << x_stride
                  << ", y_stride:" << y_stride << std::flush;
        smoothquant_args args{x_buf.GetDeviceBuffer(),
                              smscale_buf.GetDeviceBuffer(),
                              yscale_buf.GetDeviceBuffer(),
                              qy_buf.GetDeviceBuffer(),
                              m,
                              n,
                              x_stride,
                              y_stride};
        smoothquant<DataType>(args, ck_tile::stream_config{nullptr, false});
        bool pass = true;
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
        // smooth outlier: y(m_, n_) = x(m_, n_) * smscale(n_), parallelized
        // over columns
        {
            auto f = [&](auto n_) {
                auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
                for(int m_ = 0; m_ < m; ++m_)
                {
                    auto v_x       = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
                    y_host(m_, n_) = v_x * v_smscale;
                }
            };
            ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
                std::thread::hardware_concurrency());
        }
        // yscale: per-row absmax of y divided by the largest representable
        // QYDataType value
        {
            ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});
            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);
            yscale_buf.FromDevice(yscale_host_dev.mData.data());
            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }
        // rowwise quantization: quantize each row of y with its reference
        // yscale and compare against the device output
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);
            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();
            if(y_stride == n)
            {
                // Dense output: compare the whole tensor at once. Accumulate
                // with &= so an earlier yscale failure is not discarded
                // (previously this plain assignment overwrote that result).
                pass &= ck_tile::check_err(qy_host_dev,
                                           qy_host_ref,
                                           std::string("qy Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
            else
            {
                // Strided output: compare only the n valid elements of each
                // row, skipping the padding between rows.
                for(int i_r = 0; i_r < m; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
                                                            qy_host_dev.begin() + i_r * y_stride +
                                                                n);
                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
                                                            qy_host_ref.begin() + i_r * y_stride +
                                                                n);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }
        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
        EXPECT_TRUE(pass);
    }
};