diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 798c5580a2..5499689c9b 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -528,6 +528,26 @@ struct UnaryTypeConvert } }; +struct ConvInvscale +{ + /// @brief Op to multiply convolution results by inverted scale factors + /// @param e Output after scaling + /// @param c Convolution result + /// @param d0 Input scale factor + /// @param d1 Weights scale factor + /// @param d2 Output scale factor + template + __host__ __device__ void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + + template <> + __host__ __device__ void operator()( + f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + e = type_convert(c / d0 / d1 / d2); + }; +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck