mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
@@ -51,16 +51,18 @@ struct composes<F>
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename To>
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
{
|
||||
template <typename From>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<From>, From>
|
||||
// NOTE: this function does not return SaturateType value
|
||||
// it is user's responsiblity to do further cast or not
|
||||
template <typename AccType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
|
||||
{
|
||||
return clamp(from,
|
||||
type_convert<From>(numeric<To>::lowest()),
|
||||
type_convert<From>(numeric<To>::max()));
|
||||
return clamp(a_,
|
||||
type_convert<AccType>(numeric<SaturateType>::lowest()),
|
||||
type_convert<AccType>(numeric<SaturateType>::max()));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
|
||||
// scale = amax / 127 for int8
|
||||
auto v_scale = type_convert<XDataType>(scale_m(m));
|
||||
auto v_qx = v_x / v_scale;
|
||||
qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
|
||||
qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -101,6 +101,7 @@ struct MoeSmoothquant
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
@@ -118,7 +119,7 @@ struct MoeSmoothquant
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
|
||||
@@ -113,7 +113,7 @@ struct SmoothquantPipelineOnePass
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ struct SmoothquantPipelineTwoPass
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
auto qy_ = y[idx] / yscale[i_idx];
|
||||
qy(idx) = saturates<QYDataType>{}(qy_);
|
||||
qy(idx) = type_convert<QYDataType>(saturates<QYDataType>{}(qy_));
|
||||
});
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user