mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
WMMA gemm_add_relu_add_layernorm (#2989)
* Summary:
- Refactor epilogue (with CShuffle) to support fused operations:
- EpilogueCShuffleBase holds common parts
- EpilogueCShuffle: runs CShuffle and write out
- EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out
- Extend thread transfer v7r3:
- Support for intermediate data type different from src and dst type
- New functionality to write to dst buffer and keep data (to be able to use them for additional operations)
* Adress review comments
This commit is contained in:
@@ -60,8 +60,8 @@ struct AddReluAdd
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
|
||||
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
__host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
|
||||
float& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
{
|
||||
float a = x0 + x1;
|
||||
float b = a > 0 ? a : 0;
|
||||
@@ -69,6 +69,15 @@ struct AddReluAdd
|
||||
y = c;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
|
||||
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
|
||||
{
|
||||
float y_float;
|
||||
(*this)(y_float, x0, x1, x2);
|
||||
y = y_float;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
|
||||
bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const
|
||||
|
||||
Reference in New Issue
Block a user