mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
bf16A_Int8B with fastgelu/bias (#1264)
* changed the copy function to v7r2 * adding multi_abd * in-progress * add post-load oob check * debugging * adjust instances * add run_lds * add elementwise_op * replace multi_abd_device with v3 * clean up * clean * clean * Added LDSType * profiling * adjust oobcheck * add missing file * refactor * clean * add examples
This commit is contained in:
@@ -92,15 +92,6 @@ struct Add
|
||||
};
|
||||
};
|
||||
|
||||
// Y = X0 * X1, with the product formed in float precision.
struct Scales
{
    // Widen both operands to float before multiplying so that
    // low-precision input types (half/bhalf/int8) do not lose accuracy,
    // then narrow the result back to Y.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        const float x0_f = ck::type_convert<float>(x0);
        const float x1_f = ck::type_convert<float>(x1);

        y = ck::type_convert<Y>(x0_f * x1_f);
    }
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -188,6 +179,16 @@ struct Multiply
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
// int8 x bhalf -> bhalf: both operands are widened to float, multiplied,
// and the product is narrowed back to bhalf_t.
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
{
    const float x0_f = ck::type_convert<float>(x0);
    const float x1_f = ck::type_convert<float>(x1);

    y = ck::type_convert<bhalf_t>(x0_f * x1_f);
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
@@ -521,6 +522,71 @@ struct AddFastGelu
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C * D)
//
// NOTE: the previous comment said "MultiplyFastGelu(C + D)", but every
// specialization below multiplies c and d before applying FastGelu.
struct MultiplyFastGelu
{
    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

    template <>
    __host__ __device__ constexpr void
    operator()<float, float, float>(float& e, const float& c, const float& d) const
    {
        const float x = c * d;

        FastGelu{}.template operator()<float, float>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
    {
        // Product and activation both stay in half precision; FastGelu's
        // half_t specialization handles the narrow path.
        const half_t x = c * d;

        FastGelu{}.template operator()<half_t, half_t>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
    {
        // Mixed precision: multiply and activate in float, narrow at the end.
        const float x0_f = c * d;

        float x1_f = 0;

        FastGelu{}.template operator()<float, float>(x1_f, x0_f);

        e = type_convert<half_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
    {
        // bhalf has too little mantissa to multiply in-type; widen to float.
        const float x0_f = type_convert<float>(c) * type_convert<float>(d);

        float x1_f = 0;

        FastGelu{}.template operator()<float, float>(x1_f, x0_f);

        e = type_convert<bhalf_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
    {
        const float x0_f = c * type_convert<float>(d);

        float x1_f = 0;

        FastGelu{}.template operator()<float, float>(x1_f, x0_f);

        e = type_convert<bhalf_t>(x1_f);
    }
};
|
||||
|
||||
// E = Silu(C + D)
|
||||
struct AddSilu
|
||||
{
|
||||
|
||||
@@ -221,6 +221,15 @@ struct MultiplyAdd
|
||||
e = y;
|
||||
}
|
||||
// e = c * d0 + d1 for (bhalf_t out, float c, bhalf_t d0, bhalf_t d1).
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
                                                                      const float& c,
                                                                      const bhalf_t& d0,
                                                                      const bhalf_t& d1) const
{
    // NOTE(review): c is narrowed to bhalf_t *before* the multiply, so the
    // entire multiply-add runs in bhalf precision — unlike the sibling
    // specializations in this file, which widen to float first. Confirm
    // this precision loss is intentional.
    const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
    e = y;
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
@@ -240,6 +249,26 @@ struct MultiplyAdd
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C * D0 + D1)
//
// Brought in line with the sibling functors (MultiplyFastGelu, AddAddFastGelu):
// added the descriptive header comment and dropped the redundant `ck::`
// qualification — this struct lives in the same namespace, as the bare
// `FastGelu`/`bhalf_t` uses elsewhere in this file show.
struct MultiplyAddFastGelu
{
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    template <>
    __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
        bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
    {
        // Fused multiply-add in float precision, then FastGelu, then narrow.
        const float x0_f = c * type_convert<float>(d0) + type_convert<float>(d1);

        float x1_f = 0;

        FastGelu{}.template operator()<float, float>(x1_f, x0_f);

        e = type_convert<bhalf_t>(x1_f);
    }
};
|
||||
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
|
||||
@@ -504,6 +504,16 @@ struct FastGelu
|
||||
y = type_convert<half_t>(y_f);
|
||||
}
|
||||
|
||||
// Host-side bhalf output for a float input: evaluate FastGelu in float
// via the float/float overload, then narrow the result to bhalf_t.
template <>
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
    float y_float = 0;

    this->operator()<float, float>(y_float, x);

    y = type_convert<bhalf_t>(y_float);
}
|
||||
|
||||
template <>
|
||||
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user