mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fp16/fp8 mixed-precision Gemm with multiply+add fusion (#865)
* add compute_type * add multiply_add ckProfiler * add f8_fp16 support * clean * clean * fixed lds size calc * format --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -195,6 +195,51 @@ struct AddMultiply
|
||||
}
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = C x D0 + D1
|
||||
struct MultiplyAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
|
||||
const half_t& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = (c * d0) + d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = type_convert<half_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const float y = c * d0 + d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1) const
|
||||
{
|
||||
const float y = c * d0 + d1;
|
||||
e = y;
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user