WMMA gemm_add_relu_add_layernorm (#2989)

* Summary:

 - Refactor epilogue (with CShuffle) to support fused operations:
    - EpilogueCShuffleBase holds common parts
    - EpilogueCShuffle: runs CShuffle and write out
    - EpilogueWelfordCShuffle: holds Welford specific arguments, runs CShuffle, write out, Welford first part and Welford write out

 - Extend thread transfer v7r3:
    - Support for intermediate data type different from src and dst type
    - New functionality to write to dst buffer and keep data (to be able to use them for additional operations)

* Adress review comments
This commit is contained in:
Enrico Degregori
2025-10-31 19:19:26 +01:00
committed by GitHub
parent e9596228ff
commit 4ebc48a3cd
23 changed files with 2678 additions and 332 deletions

View File

@@ -60,8 +60,8 @@ struct AddReluAdd
}
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
__host__ __device__ constexpr void operator()<float, float, half_t, half_t>(
float& y, const float& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a > 0 ? a : 0;
@@ -69,6 +69,15 @@ struct AddReluAdd
y = c;
}
template <>
__host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
{
float y_float;
(*this)(y_float, x0, x1, x2);
y = y_float;
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const