mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Add support for mixed-precision f16bf16_int8 gemm (#1127)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -75,6 +75,15 @@ struct Add
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
// y = x0 + x1 for a float accumulator and a bf16 addend: widen the bf16 value
// to float, add in float, then narrow the sum back to bf16 for the store.
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
    const float x1_fp32 = ck::type_convert<float>(x1);

    y = ck::type_convert<bhalf_t>(x0 + x1_fp32);
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
@@ -264,6 +273,14 @@ struct AddRelu
|
||||
y = a > 0.0f ? a : 0.0f;
|
||||
};
|
||||
|
||||
// y = ReLU(x0 + x1) for a float accumulator and a bf16 addend: widen the bf16
// value to float, add and clamp at zero in float, then narrow to bf16.
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
    const float a = x0 + type_convert<float>(x1);
    // Compare and select in float, and narrow the selected value with
    // type_convert. The previous form (`y = a > type_convert<bhalf_t>(0.0f) ?
    // a : ...`) selected a raw float and assigned it to the integer-backed
    // bf16 output via implicit conversion, value-truncating the result
    // instead of rounding it to bf16.
    y = a > 0.0f ? type_convert<bhalf_t>(a) : type_convert<bhalf_t>(0.0f);
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
|
||||
@@ -354,6 +371,70 @@ struct AddFastGelu
|
||||
|
||||
e = type_convert<half_t>(x1_f);
|
||||
}
|
||||
|
||||
// e = FastGelu(c + d) for a float accumulator and a bf16 addend: widen d,
// run the activation entirely in float, then narrow the result to bf16.
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
    const float sum_f = c + type_convert<float>(d);

    float gelu_f = 0;

    FastGelu{}.template operator()<float, float>(gelu_f, sum_f);

    e = type_convert<bhalf_t>(gelu_f);
}
|
||||
};
|
||||
|
||||
// E = Silu(C + D)
struct AddSilu
{
    // Generic entry point; only the explicit specializations below are defined.
    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

    // All-float path: add and activate directly in fp32.
    template <>
    __host__ __device__ constexpr void
    operator()<float, float, float>(float& e, const float& c, const float& d) const
    {
        const float sum = c + d;

        Silu{}.template operator()<float>(e, sum);
    }

    // All-fp16 path: add and activate directly in fp16.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
    {
        const half_t sum = c + d;

        Silu{}.template operator()<half_t>(e, sum);
    }

    // fp32 accumulator + fp16 addend: compute in fp32, then store as fp16.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
    {
        const float sum = c + d;

        float activated = 0;

        Silu{}.template operator()<float>(activated, sum);

        e = type_convert<half_t>(activated);
    }

    // fp32 accumulator + bf16 addend: widen d explicitly via type_convert,
    // compute in fp32, then store as bf16.
    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
    {
        const float sum = c + type_convert<float>(d);

        float activated = 0;

        Silu{}.template operator()<float>(activated, sum);

        e = type_convert<bhalf_t>(activated);
    }
};
|
||||
|
||||
} // namespace element_wise
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -156,6 +156,12 @@ struct PassThrough
|
||||
y = type_convert<half_t>(x);
|
||||
}
|
||||
|
||||
// Pass-through with a type change: converts an int8 value to bf16.
template <>
__host__ __device__ void operator()<bhalf_t, int8_t>(bhalf_t& y, const int8_t& x) const
{
    const bhalf_t x_bf16 = type_convert<bhalf_t>(x);

    y = x_bf16;
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
@@ -551,6 +557,19 @@ struct Sigmoid
|
||||
};
|
||||
};
|
||||
|
||||
// y = x * sigmoid(x)
struct Silu
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
                          is_same_v<T, int8_t> || is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");

        constexpr T one = type_convert<T>(1);
        const T sigmoid = one / (one + ck::math::exp(-x));

        y = x * sigmoid;
    };
};
|
||||
|
||||
struct TanH
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user