mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add a convinvscale op, related instances and examples (#1307)
* Update the element op
* Add an example
* Add instances
* Add a client example
* Make sure new instances only build on gfx9
* Update element op and its handling
* Format
* Update instances to take element op as an argument
* Update examples to use random scale values
* Format
* Update client example with random scales
* Format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
@@ -528,26 +528,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
/// @brief Op to multiply convolution results by inverted scale factors
|
||||
/// @param e Output after scaling
|
||||
/// @param c Convolution result
|
||||
/// @param d0 Input scale factor
|
||||
/// @param d1 Weights scale factor
|
||||
/// @param d2 Output scale factor
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float, float, float, float>(
|
||||
f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
e = type_convert<f8_t>(c / d0 / d1 / d2);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -961,6 +961,29 @@ struct Elu
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
__host__ __device__ ConvInvscale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename E, typename C>
|
||||
__host__ __device__ void operator()(E& e, const C& c) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
|
||||
{
|
||||
e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
};
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
__host__ __device__ ConvScale(float scale_in = 1.f,
|
||||
|
||||
Reference in New Issue
Block a user