mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
fixed fp8 issues (#894)
* fixed fp8 init; and reference gemm * Update host_tensor_generator.hpp * fixed convert * fixed reference gemm * fixed comments * fixed comments * fixed ci * fixed computeType --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -27,6 +27,12 @@ struct PassThrough
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
@@ -81,6 +87,12 @@ struct PassThrough
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
@@ -416,14 +428,19 @@ struct Swish
|
||||
{
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value,
|
||||
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
|
||||
is_same<X, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
|
||||
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
|
||||
is_same<Y, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
float bx = -beta_ * type_convert<float>(x);
|
||||
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
|
||||
};
|
||||
|
||||
float beta_ = 1.0f;
|
||||
|
||||
Reference in New Issue
Block a user