Revert "Add support for mixed precision in contraction scale and bilinear" (#967)

* Revert "Add support for mixed precision in contraction scale and bilinear (#936)"

This reverts commit f07485060e.

* revert commits #957 and #960
This commit is contained in:
Illia Silin
2023-10-05 14:58:23 -07:00
committed by GitHub
parent 570ff3ddbe
commit 4daedf8ca5
99 changed files with 1961 additions and 7536 deletions

View File

@@ -186,25 +186,6 @@ struct Bilinear
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x0_tmp = type_convert<float>(x0);
const float x1_tmp = type_convert<float>(x1);
const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
y = type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
y = y_tmp;
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const

View File

@@ -33,12 +33,6 @@ struct PassThrough
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
@@ -75,12 +69,6 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, bhalf_t>(float& y, const bhalf_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
@@ -243,20 +231,6 @@ struct Scale
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = ck::type_convert<half_t>(scale_) * x;
};
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
const float x_tmp = ck::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
};
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{