mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Adding remaining conv, dynamic_op, and scaleadd_scaleadd_relu flavors for grouped conv fwd (#3529)
* Adding remaining flavors for grouped conv fwd As titled. Following variants are added: - grouped_conv2d_fwd_dynamic_op - grouped_conv3d_fwd_dynamic_op - grouped_conv3d_fwd_bilinear - grouped_conv3d_fwd_convscale - grouped_conv3d_fwd_convinvscale - grouped_conv3d_fwd_convscale_add - grouped_conv3d_fwd_convscale_relu - grouped_conv3d_fwd_scale - grouped_conv3d_fwd_combconvscale - grouped_conv3d_fwd_scaleadd_scaleadd_relu * Fix incomplete parsing of types from source names in add_instance_library() cmakelists function so we don't build f8 on RDNA3. * Do not build f8 / bf8 only flavor tests on RDNA3 * Make sure we have proper generic instances for all instance lists related to the post-ces extra flavors, with scalarPerVector = 1. Then disable all but one generic instance per instance list to reduce compile time. * Post rebase fix: Template parameters for Grouped Conv Fwd Device Impl got tweaked upstream. * adding int8 and fp16 overloads to the elementwise operations * fixed copilot nits * Addressing review comments: - removed unnecessary examples for dynamic op - removed unnecessary conv specializations for all the flavors - removed spurious bilinear and scale source files * clang-format * reduced no of tests --------- Co-authored-by: Wojciech Laskowski <wojciech.laskowski@streamhpc.com>
This commit is contained in:
committed by
GitHub
parent
6a6177a246
commit
2377a62837
@@ -791,6 +791,18 @@ struct UnaryAbs
|
||||
{
|
||||
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = ck::type_convert<int8_t>(ck::math::abs(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
y = ck::type_convert<half_t>(ck::math::abs(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
@@ -913,6 +925,20 @@ struct Relu
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<bhalf_t>(y_f32);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<int8_t>(y_f32);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
float y_f32 = x > 0 ? x : 0;
|
||||
y = type_convert<half_t>(y_f32);
|
||||
};
|
||||
};
|
||||
|
||||
// Fast GeLU
|
||||
@@ -1081,6 +1107,20 @@ struct Sigmoid
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<int8_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<half_t>(one / (one + math::exp(-x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Silu
|
||||
@@ -1121,6 +1161,18 @@ struct TanH
|
||||
{
|
||||
y = type_convert<bhalf_t>(math::tanh(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(math::tanh(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<half_t>(math::tanh(x));
|
||||
};
|
||||
};
|
||||
|
||||
struct ACos
|
||||
@@ -1453,6 +1505,21 @@ struct SoftRelu
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<int8_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<half_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1487,6 +1554,20 @@ struct Power
|
||||
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
const float shifted_scaled_x = alpha_ + beta_ * x;
|
||||
y = type_convert<int8_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
const float shifted_scaled_x = alpha_ + beta_ * x;
|
||||
y = type_convert<half_t>(math::pow(shifted_scaled_x, gamma_));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
const float gamma_;
|
||||
@@ -1519,6 +1600,18 @@ struct ClippedRelu
|
||||
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<half_t>(math::min(beta_, math::max(alpha_, x)));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
const float beta_;
|
||||
};
|
||||
@@ -1549,6 +1642,18 @@ struct LeakyRelu
|
||||
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<half_t>(x >= 0 ? x : x * alpha_);
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1578,6 +1683,18 @@ struct Elu
|
||||
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<half_t>(x > 0 ? x : alpha_ * math::expm1(x));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
@@ -1608,6 +1725,21 @@ struct Logistic
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<int8_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
|
||||
{
|
||||
constexpr float one = 1.f;
|
||||
y = type_convert<half_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
|
||||
};
|
||||
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
|
||||
@@ -293,6 +293,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
// convolution forward. For some reason for that specific type there is an ambiguity
|
||||
// in the type resolution for the ternary expression. I added an explicit cast to
|
||||
// disambiguate and only use it for f8 just in case it affects performance.
|
||||
// TODO: Add same exception for ck::f8_fnuz_t?
|
||||
if constexpr(is_same_v<scalar_t, ck::f8_ocp_t>)
|
||||
{
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
|
||||
Reference in New Issue
Block a user