add fp8 as dst (#1830)

This commit is contained in:
carlushuang
2025-01-22 17:34:27 +08:00
committed by GitHub
parent 1fe2c35291
commit 052a72655c
30 changed files with 300 additions and 194 deletions

View File

@@ -8,26 +8,13 @@
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
template <typename DataType>
struct MoeSmoothquantTypeConfig;
template <>
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
template <typename InputType, typename OutputType>
struct MoeSmoothquantTypeConfig
{
using XDataType = ck_tile::half_t;
using XDataType = InputType;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using QYDataType = OutputType;
using ComputeDataType = float;
};
@@ -37,7 +24,8 @@ struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
};
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <typename DataType_,
template <typename InputType_,
typename OutputType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
@@ -47,7 +35,8 @@ template <typename DataType_,
bool kTwoPass_>
struct moe_smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
using InputType = ck_tile::remove_cvref_t<InputType_>;
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
@@ -108,7 +97,8 @@ float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
// This is the public API, will be generated by script
struct moe_smoothquant_traits
{
std::string data_type;
std::string in_type; // input type
std::string out_type; // output type
};
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);