mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
add fp8 as dst (#1830)
This commit is contained in:
@@ -8,26 +8,13 @@
|
||||
#include "ck_tile/ops/smoothquant.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
struct MoeSmoothquantTypeConfig;
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
|
||||
template <typename InputType, typename OutputType>
|
||||
struct MoeSmoothquantTypeConfig
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XDataType = InputType;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using QYDataType = OutputType;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
@@ -37,7 +24,8 @@ struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename DataType_,
|
||||
template <typename InputType_,
|
||||
typename OutputType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
@@ -47,7 +35,8 @@ template <typename DataType_,
|
||||
bool kTwoPass_>
|
||||
struct moe_smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
@@ -108,7 +97,8 @@ float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
|
||||
// This is the public API, will be generated by script
|
||||
struct moe_smoothquant_traits
|
||||
{
|
||||
std::string data_type;
|
||||
std::string in_type; // input type
|
||||
std::string out_type; // output type
|
||||
};
|
||||
|
||||
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
|
||||
|
||||
Reference in New Issue
Block a user