add fp8 as dst (#1830)

This commit is contained in:
carlushuang
2025-01-22 17:34:27 +08:00
committed by GitHub
parent 1fe2c35291
commit 052a72655c
30 changed files with 300 additions and 194 deletions

View File

@@ -63,7 +63,8 @@ auto create_args(int argc, char* argv[])
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("prec_i", "fp16", "input precision, fp16/bf16")
.insert("prec_o", "int8", "precision, int8/fp8")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
@@ -71,7 +72,7 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
template <typename InputType, typename OutputType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
@@ -81,7 +82,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride = hidden_size;
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
std::string data_type = arg_parser.get_str("prec");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
@@ -89,7 +91,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert(stride >= hidden_size);
using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
using TypeConfig = MoeSmoothquantTypeConfig<InputType, OutputType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
@@ -122,11 +124,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
smscale_buf.ToDevice(smscale_host.data());
topk_ids_buf.ToDevice(topk_ids_host.data());
std::cout << "[" << data_type << "]"
std::cout << "[" << prec_i << "-" << prec_o << "]"
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
<< ", experts:" << experts << ", topk:" << topk << std::flush;
moe_smoothquant_traits traits{data_type};
moe_smoothquant_traits traits{prec_i, prec_o};
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
@@ -251,14 +253,23 @@ int main(int argc, char* argv[])
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
const std::string prec_i = arg_parser.get_str("prec_i");
const std::string prec_o = arg_parser.get_str("prec_o");
if(prec_i == "fp16" && prec_o == "int8")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::int8_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
else if(prec_i == "fp16" && prec_o == "fp8")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8")
{
return run<ck_tile::bf16_t, ck_tile::int8_t>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8")
{
return run<ck_tile::bf16_t, ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
return -3;