mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add a scale op, related instances and examples (#1242)
* Add a scale op
* Update the element op
* Add instances
* Add an example
* Add a client example
* Add a flag check
* Revert flag check addition
* Fix flag check
* Update d strides in example
* Update d strides in client example
* Apply suggestions from code review

  Update copyright header

  Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
* Move the example
* Move the client example
* Update element op
* Update example with the new element op
* Add scalar layout
* Update example
* Update kernel for scalar Ds
* Revert kernel changes
* Update element op
* Update example to use scales' pointers
* Format
* Update instances
* Update client example
* Move element op to unary elements
* Update element op to work with values instead of pointers
* Update instances to take element op as an argument
* Update examples to use random scale values

---------

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -961,6 +961,29 @@ struct Elu
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
__host__ __device__ ConvScale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename E, typename C>
|
||||
__host__ __device__ void operator()(E& e, const C& c) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
|
||||
{
|
||||
e = type_convert<f8_t>(c * scale_in_ * scale_wei_ * scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
};
|
||||
|
||||
// support fastconvert of int8 to fp16
|
||||
|
||||
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
|
||||
|
||||
Reference in New Issue
Block a user