// Mirror of https://github.com/ROCm/composable_kernel.git
// (synced 2026-07-01 12:17:00 +00:00; 524 lines, 22 KiB, C++)
#include "ck_tile/host.hpp"
#include "fused_moegemm.hpp"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

// different threshold for different dtype
|
|
template <typename DataType>
|
|
auto get_elimit()
|
|
{
|
|
double rtol = 1e-2;
|
|
double atol = 1e-2;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
template <>
|
|
auto get_elimit<ck_tile::bf16_t>()
|
|
{
|
|
double rtol = 1e-2;
|
|
double atol = 1e-2;
|
|
return ck_tile::make_tuple(rtol, atol);
|
|
}
|
|
|
|
// mfma_type, 0:32x32, 1:16x16
|
|
// TODO: padding?
|
|
template <typename T>
|
|
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
|
{
|
|
static_assert(t.get_lengths().size() == 3);
|
|
int b_ = t.get_lengths()[0];
|
|
int n_ = t.get_lengths()[1];
|
|
int k_ = t.get_lengths()[2];
|
|
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
|
{
|
|
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
|
}
|
|
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
|
{
|
|
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
|
}
|
|
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
|
{
|
|
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
|
}
|
|
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
|
{
|
|
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
|
std::copy(t.begin(), t.end(), t_view.begin());
|
|
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
|
}
|
|
return t;
|
|
}
|
|
|
|
// Fill host_tensor[0 .. tokens*topk) with `tokens` rows of `topk` expert ids each.
// Ids are drawn from [0, num_expert) via std::rand() after std::srand(seed), so the
// output is reproducible for a given seed. A draw is retried while it repeats an id
// already used in the current row, so ids within a row are distinct.
// Does nothing when tokens <= 0 or topk <= 0.
// Throws std::invalid_argument when topk > num_expert (no distinct row exists; the
// retry loop would never terminate) or when host_tensor holds fewer than tokens*topk
// elements.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    if(tokens <= 0 || topk <= 0)
        return;
    if(topk > num_expert)
        throw std::invalid_argument("topid_unique_gen: topk must not exceed num_expert");

    const std::size_t total_size =
        static_cast<std::size_t>(tokens) * static_cast<std::size_t>(topk);
    if(host_tensor.size() < total_size)
        throw std::invalid_argument(
            "topid_unique_gen: output buffer is smaller than tokens * topk");

    std::srand(seed);
    std::set<IndexType> used_in_row; // ids already drawn for the current token
    for(std::size_t i = 0; i < total_size; i++)
    {
        if(i % static_cast<std::size_t>(topk) == 0)
            used_in_row.clear(); // start of a new token's row

        IndexType current_v = std::rand() % num_expert;
        while(used_in_row.find(current_v) != used_in_row.end())
            current_v = std::rand() % num_expert;

        used_in_row.insert(current_v);
        host_tensor[i] = current_v;
    }
}

auto create_args(int argc, char* argv[])
|
|
{
|
|
ck_tile::ArgParser arg_parser;
|
|
arg_parser.insert("t", "128", "num input tokens")
|
|
.insert("e", "32", "num of experts")
|
|
.insert("k", "5", "topk")
|
|
.insert("h", "8192", "hidden_size of this model")
|
|
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
|
|
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
|
.insert("bm", "32", "blocking factor for sorted tokens")
|
|
.insert("tp", "8", "tensor parallel size")
|
|
.insert("v", "1", "cpu validation or not")
|
|
.insert("kname", "1", "print kernel name or not")
|
|
.insert("prec_i", "bf16", "input precision")
|
|
.insert("prec_w", "bf16", "weight precision")
|
|
.insert("prec_o", "bf16", "output precision")
|
|
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
|
|
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
|
|
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
|
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
|
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
|
.insert(
|
|
"gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
|
.insert("balance",
|
|
"1",
|
|
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
|
.insert("warmup", "5", "cold iter")
|
|
.insert("repeat", "20", "hot iter");
|
|
|
|
bool result = arg_parser.parse(argc, argv);
|
|
return std::make_tuple(result, arg_parser);
|
|
}
|
|
|
|
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
|
|
// SQ:smooth-quant-type, KW:topk-weight-type
|
|
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
|
bool run(const ck_tile::ArgParser& arg_parser)
|
|
{
|
|
ck_tile::index_t tokens = arg_parser.get_int("t");
|
|
ck_tile::index_t experts = arg_parser.get_int("e");
|
|
ck_tile::index_t topk = arg_parser.get_int("k");
|
|
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
|
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
|
ck_tile::index_t stride = arg_parser.get_int("stride");
|
|
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
|
if(stride < 0)
|
|
stride = hidden_size;
|
|
std::string prec_i = arg_parser.get_str("prec_i");
|
|
std::string prec_w = arg_parser.get_str("prec_w");
|
|
std::string prec_o = arg_parser.get_str("prec_o");
|
|
std::string prec_st = arg_parser.get_str("prec_st");
|
|
std::string prec_sw = arg_parser.get_str("prec_sw");
|
|
std::string prec_sq = arg_parser.get_str("prec_sq");
|
|
std::string prec_kw = arg_parser.get_str("prec_kw");
|
|
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
|
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
|
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
|
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
|
int kname = arg_parser.get_int("kname");
|
|
int do_validation = arg_parser.get_int("v");
|
|
int warmup = arg_parser.get_int("warmup");
|
|
int repeat = arg_parser.get_int("repeat");
|
|
int fused_quant = arg_parser.get_int("fquant");
|
|
int gate_only = arg_parser.get_int("gate_only");
|
|
int balance = arg_parser.get_int("balance");
|
|
int tp = arg_parser.get_int("tp");
|
|
|
|
ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
|
|
|
|
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
|
using ADataType = typename TypeConfig::ADataType;
|
|
using GDataType = typename TypeConfig::GDataType;
|
|
using DDataType = typename TypeConfig::DDataType;
|
|
using AccDataType = typename TypeConfig::AccDataType;
|
|
using ODataType = typename TypeConfig::ODataType;
|
|
using AScaleDataType = typename TypeConfig::AScaleDataType;
|
|
using GScaleDataType = typename TypeConfig::GScaleDataType;
|
|
using DScaleDataType = typename TypeConfig::DScaleDataType;
|
|
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
|
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
|
|
using IndexDataType = typename TypeConfig::IndexDataType;
|
|
|
|
// host verify
|
|
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
|
|
ck_tile::HostTensor<GDataType> g_host({e, shared_intermediate_size, hidden_size});
|
|
ck_tile::HostTensor<DDataType> d_host({e, intermediate_size, hidden_size});
|
|
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
|
|
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
|
|
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
|
|
ck_tile::HostTensor<DScaleDataType> sd_host({intermediate_size});
|
|
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({intermediate_size}); // smooth-quant
|
|
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
|
|
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
|
|
|
|
int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
|
|
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
|
|
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
|
|
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
|
|
{(max_num_tokens_padded + block_m - 1) / block_m});
|
|
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
|
|
|
// permute weight
|
|
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w);
|
|
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w);
|
|
|
|
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
|
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_perm_host);
|
|
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_perm_host);
|
|
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
|
|
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
|
|
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
|
|
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
|
|
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
|
|
|
|
// do moe sorting
|
|
if(balance)
|
|
{
|
|
int e_cnt = 0 for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
|
|
{
|
|
topk_ids_host.mData[i] = e_cnt;
|
|
e_cnt++;
|
|
if(e_cnt >= experts)
|
|
e_cnt = 0;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, experts, 11913);
|
|
}
|
|
|
|
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
|
topk_ids_host,
|
|
topk_weight_host,
|
|
sorted_token_ids_host,
|
|
sorted_weight_host,
|
|
sorted_expert_ids_host,
|
|
num_sorted_tiles_host.mData[0],
|
|
experts,
|
|
block_m);
|
|
// done, preparing GPU buffer
|
|
ck_tile::DeviceMem a_buf(a_host);
|
|
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
|
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
|
ck_tile::DeviceMem sa_buf(sa_host);
|
|
ck_tile::DeviceMem sg_buf(sg_host);
|
|
ck_tile::DeviceMem sd_buf(sd_host);
|
|
ck_tile::DeviceMem sy_buf(sy_host);
|
|
ck_tile::DeviceMem o_buf(o_host);
|
|
|
|
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
|
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
|
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
|
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
|
|
|
auto prec_str = [&]() {
|
|
auto base_str = prec_i;
|
|
if(prec_i != prec_w)
|
|
base_str += "x" + prec_w;
|
|
if(prec_i != prec_o)
|
|
base_str += "=" + prec_o;
|
|
if(fused_quant != 0)
|
|
{
|
|
base_str += std::string("(") + prec_sa + "|" + prec_sg + "|" + prec_sq + ")";
|
|
}
|
|
return base_str;
|
|
}();
|
|
|
|
std::cout << "[" << prec_str << "]"
|
|
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
|
|
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
|
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
|
|
|
fused_moegemm_traits traits{prec_i,
|
|
prec_w,
|
|
prec_o,
|
|
prec_st,
|
|
prec_sw,
|
|
prec_sq,
|
|
prec_kw,
|
|
block_m,
|
|
gate_only,
|
|
fused_quant};
|
|
|
|
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
|
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
|
g_buf.GetDeviceBuffer(),
|
|
d_buf.GetDeviceBuffer(),
|
|
fused_quant != 0
|
|
? sg_buf.GetDeviceBuffer(),
|
|
fused_quant != 0
|
|
? sd_buf.GetDeviceBuffer(),
|
|
fused_quant == 1
|
|
? sy_buf.GetDeviceBuffer(),
|
|
o_buf.GetDeviceBuffer(),
|
|
sorted_token_ids_buf.GetDeviceBuffer(),
|
|
sorted_weight_buf.GetDeviceBuffer(),
|
|
sorted_expert_ids_buf.GetDeviceBuffer(),
|
|
num_sorted_tiles_buf.GetDeviceBuffer(),
|
|
hidden_size,
|
|
intermediate_size,
|
|
num_tokens,
|
|
experts,
|
|
stride };
|
|
|
|
float ave_time = fused_moegemm(
|
|
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
|
|
|
if(ave_time < 0)
|
|
{
|
|
std::cout << " not supported!" << std::endl << std::flush;
|
|
return false;
|
|
}
|
|
|
|
#if 0
|
|
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
|
|
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
|
|
|
|
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
|
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
|
#else
|
|
std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size;
|
|
std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size;
|
|
double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12;
|
|
|
|
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
|
std::cout << ", " << ave_time * 1.E3 << " us, " << tflops << " tflops" << std::flush;
|
|
#endif
|
|
bool pass = true;
|
|
|
|
if(do_validation)
|
|
{
|
|
#if 0
|
|
// reference
|
|
if(fused_add != 0)
|
|
{
|
|
// fused pre_add/pre_add_store
|
|
// TODO we accumulate directly to a_host for simplcity here...
|
|
std::transform(a_host.mData.cbegin(),
|
|
a_host.mData.cend(),
|
|
x_residual_host.mData.cbegin(),
|
|
a_host.mData.begin(),
|
|
[](auto x_, auto r_) {
|
|
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
|
|
ck_tile::type_convert<ComputeDataType>(r_);
|
|
return ck_tile::type_convert<ADataType>(o_);
|
|
});
|
|
}
|
|
ck_tile::reference_layernorm2d_fwd<ADataType,
|
|
GammaDataType,
|
|
BetaDataType,
|
|
ComputeDataType,
|
|
YDataType,
|
|
MeanDataType,
|
|
InvStdDataType>(
|
|
a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
|
|
|
if(fused_quant != 0)
|
|
{
|
|
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
|
|
int N_ = acc_.mDesc.get_lengths()[1];
|
|
if(fused_quant == 1)
|
|
{
|
|
for(int n_ = 0; n_ < N_; n_++)
|
|
{
|
|
// input smooth outlier
|
|
acc_(m_, n_) =
|
|
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
|
|
}
|
|
}
|
|
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
|
for(int n_ = 0; n_ < N_; n_++)
|
|
{
|
|
const auto a = ck_tile::abs(acc_(m_, n_));
|
|
absmax = a > absmax ? a : absmax;
|
|
}
|
|
// printf("cpu:absmax:%f\n", absmax);
|
|
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
|
|
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
|
for(int n_ = 0; n_ < N_; n_++)
|
|
{
|
|
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
|
|
}
|
|
};
|
|
|
|
ck_tile::reference_layernorm2d_fwd<ADataType,
|
|
GammaDataType,
|
|
BetaDataType,
|
|
ComputeDataType,
|
|
YDataType,
|
|
MeanDataType,
|
|
InvStdDataType>(a_host,
|
|
gamma_host,
|
|
beta_host,
|
|
y_host_ref,
|
|
mean_host_ref,
|
|
invStd_host_ref,
|
|
epsilon,
|
|
dquant_functor);
|
|
}
|
|
else
|
|
{
|
|
ck_tile::reference_layernorm2d_fwd<ADataType,
|
|
GammaDataType,
|
|
BetaDataType,
|
|
ComputeDataType,
|
|
YDataType,
|
|
MeanDataType,
|
|
InvStdDataType>(
|
|
a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
|
}
|
|
|
|
y_buf.FromDevice(y_host_dev.data());
|
|
|
|
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
|
|
if(fused_add == 1)
|
|
{
|
|
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
|
}
|
|
|
|
auto [rtol, atol] = get_elimit<InDataType>();
|
|
|
|
if(stride == n)
|
|
{
|
|
pass = ck_tile::check_err(
|
|
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
|
if(fused_add == 1)
|
|
{
|
|
pass &= ck_tile::check_err(y_residual_host_dev,
|
|
a_host,
|
|
std::string("ADD Error: Incorrect results!"),
|
|
rtol,
|
|
atol);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
for(int i_r = 0; i_r < m; i_r++)
|
|
{
|
|
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
|
|
y_host_dev.begin() + i_r * stride + n);
|
|
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
|
|
y_host_ref.begin() + i_r * stride + n);
|
|
pass &= ck_tile::check_err(y_host_dev_row,
|
|
y_host_ref_row,
|
|
std::string("OUT[") + std::to_string(i_r) +
|
|
std::string("] Error: Incorrect results!"),
|
|
rtol,
|
|
atol);
|
|
if(fused_add == 1)
|
|
{
|
|
std::vector<YResidualDataType> y_residual_host_dev_row(
|
|
y_residual_host_dev.begin() + i_r * stride,
|
|
y_residual_host_dev.begin() + i_r * stride + n);
|
|
std::vector<YResidualDataType> y_residual_host_ref_row(
|
|
a_host.begin() + i_r * stride, a_host.begin() + i_r * stride + n);
|
|
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
|
y_residual_host_ref_row,
|
|
std::string("ADD[") + std::to_string(i_r) +
|
|
std::string("] Error: Incorrect results!"),
|
|
rtol,
|
|
atol);
|
|
}
|
|
}
|
|
}
|
|
if(fused_quant == 1)
|
|
{
|
|
y_scale_buf.FromDevice(y_scale_host_dev.data());
|
|
pass &= ck_tile::check_err(y_scale_host_dev,
|
|
y_scale_host_ref,
|
|
std::string("SCALE Error: Incorrect results!"),
|
|
rtol,
|
|
atol);
|
|
}
|
|
|
|
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
|
#else
|
|
std::cout << std::flush << std::endl;
|
|
#endif
|
|
}
|
|
|
|
return pass;
|
|
}
|
|
|
|
int main(int argc, char* argv[])
|
|
{
|
|
auto [result, arg_parser] = create_args(argc, argv);
|
|
if(!result)
|
|
return -1;
|
|
|
|
std::string prec_i = arg_parser.get_str("prec_i");
|
|
std::string prec_o = arg_parser.get_str("prec_o");
|
|
std::string prec_sx = arg_parser.get_str("prec_sx");
|
|
std::string prec_sy = arg_parser.get_str("prec_sy");
|
|
|
|
if(prec_o == "auto")
|
|
{
|
|
prec_o = prec_i;
|
|
}
|
|
if(prec_sx == "auto")
|
|
{
|
|
prec_sx = "fp32";
|
|
}
|
|
if(prec_sy == "auto")
|
|
{
|
|
prec_sy = "fp32";
|
|
}
|
|
int save_mv = arg_parser.get_int("save_mv");
|
|
|
|
// no dynamic quant case
|
|
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
|
}
|
|
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
|
}
|
|
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
|
}
|
|
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
|
}
|
|
|
|
// dynamic quant case, only in inference
|
|
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
|
}
|
|
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
|
|
{
|
|
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
|
}
|
|
|
|
return -3;
|
|
}
|