Add support for mixed-precision f16bf16_int8 gemm (#1127)

This commit is contained in:
jakpiase
2024-02-07 15:54:13 +01:00
committed by GitHub
parent 753cef783f
commit ba86eadce5
33 changed files with 2424 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -75,6 +75,15 @@ struct Add
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x0 + x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
@@ -264,6 +273,14 @@ struct AddRelu
y = a > 0.0f ? a : 0.0f;
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
};
template <>
__host__ __device__ constexpr void
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
@@ -354,6 +371,70 @@ struct AddFastGelu
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
const float x0_f = c + type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
};
// E = Silu(C + D)
struct AddSilu
{
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
const float x = c + d;
Silu{}.template operator()<float>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c + d;
Silu{}.template operator()<half_t>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
const float x0_f = c + d;
float x1_f = 0;
Silu{}.template operator()<float>(x1_f, x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
const float x0_f = c + type_convert<float>(d);
float x1_f = 0;
Silu{}.template operator()<float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
};
} // namespace element_wise

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -156,6 +156,12 @@ struct PassThrough
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, int8_t>(bhalf_t& y, const int8_t& x) const
{
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
@@ -551,6 +557,19 @@ struct Sigmoid
};
};
struct Silu
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
is_same_v<T, int8_t> || is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck::math::exp(-x)));
};
};
struct TanH
{
template <typename T>