mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Gemm multiple d multiple r (#335)
* Imitate XXX_gemm_multiple_d, add XXX_gemm_multiple_d_multiple_r for gemm + reduction * Implement run of kernel * Add example * Fix parameter of typo * Rewrite the reduceMax example * Rewrite the reduceMean + reduceMeanSquare example * Refine naming * Refine folder name * refine naming * Rewrite the gemm + bias + relu + add + layernorm example * Rewrite the gemm + layernorm example * clang-format * Fix bug if sync lds * Fix compile error
This commit is contained in:
@@ -130,6 +130,35 @@ struct AddHardswishAdd
|
||||
}
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = C + D0 + D1
|
||||
struct AddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
||||
{
|
||||
// Only support floating so far
|
||||
static_assert(is_same<E, half_t>::value || is_same<E, float>::value ||
|
||||
is_same<E, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
static_assert(is_same<C, half_t>::value || is_same<C, float>::value ||
|
||||
is_same<C, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
static_assert(is_same<D0, half_t>::value || is_same<D0, float>::value ||
|
||||
is_same<D0, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
static_assert(is_same<D1, half_t>::value || is_same<D1, float>::value ||
|
||||
is_same<D1, double>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
|
||||
e = type_convert<E>(y);
|
||||
}
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
|
||||
Reference in New Issue
Block a user