mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add support for mixed precision in contraction scale and bilinear (#973)
* Add support for mixed precision in contraction scale and bilinear (#936) * Extract common functionality to separate files * Reference contraction: Remove incorrect consts from type_converts * Reference contraction: Add missing type_convert for dst value * Reference contraction: Fix incorrect order of B matrix dimensions * Add support for mixed precision in contraction scale and bilinear * Move using statements from instances to a common file * Move using statements from examples to a common file * Fix the order of B matrix dimensions across examples and profiler * Fix the computation of error threshold * Make ComputeDataType an optional argument * Include possible DataType -> ComputeDataType casting error in the threshold * Remove commented code * Make the ComputeDataType an optional argument in instance --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
73743aa0aa
commit
4ef704d8a6
@@ -186,6 +186,25 @@ struct Bilinear
|
||||
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x0_tmp = type_convert<float>(x0);
|
||||
const float x1_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
|
||||
y = type_convert<bhalf_t>(y_tmp);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
|
||||
y = y_tmp;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
|
||||
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
|
||||
|
||||
@@ -33,6 +33,12 @@ struct PassThrough
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
@@ -69,6 +75,12 @@ struct PassThrough
|
||||
y = type_convert<bhalf_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
|
||||
{
|
||||
@@ -225,6 +237,20 @@ struct Scale
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
|
||||
{
|
||||
y = ck::type_convert<half_t>(scale_) * x;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
|
||||
{
|
||||
const float x_tmp = ck::type_convert<float>(x);
|
||||
const float y_tmp = scale_ * x_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user