add fp8 as dst (#1830)

This commit is contained in:
carlushuang
2025-01-22 17:34:27 +08:00
committed by GitHub
parent 1fe2c35291
commit 052a72655c
30 changed files with 300 additions and 194 deletions

View File

@@ -101,6 +101,7 @@ struct MoeSmoothquant
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
// clang-format on
// in byte
@@ -118,7 +119,7 @@ struct MoeSmoothquant
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;

View File

@@ -113,7 +113,7 @@ struct SmoothquantPipelineOnePass
sweep_tile(qy, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
});
store_tile(qy_window, qy);
}

View File

@@ -136,7 +136,7 @@ struct SmoothquantPipelineTwoPass
sweep_tile(qy, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
});
store_tile(qy_window, qy);