mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
add fp8 as dst (#1830)
This commit is contained in:
@@ -101,6 +101,7 @@ struct MoeSmoothquant
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
@@ -118,7 +119,7 @@ struct MoeSmoothquant
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
|
||||
@@ -113,7 +113,7 @@ struct SmoothquantPipelineOnePass
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ struct SmoothquantPipelineTwoPass
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user