mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Single-kernel GEMM + layernorm (#263)
* dump lds content in appropriate precision type * add squared add reduction op; allows sq sum * initial stub from regular gemm impl * layernorm example code & host verification * initial layernorm implementation * tidy up * make C0 precision type consistent with C * clang-tidy and additional comments * tighten up example code * account for extra flops/bytes from normalization * clang-format * c0 bias/beta/gamma now have its own precision type * AccElemOp for gemm outputs prior to feeding to layernorm * update workgroup mapping * rename kernel template param to reflect its dual use * use LDS mem pool for reduction workspace * change cshuffle precision type to f16; clean up * clang-format * correct naming * explicit cast * fully implemented gemm + bias + activation + add + norm * activation in correct order * reflect reduction API's recent change * amend * clean up; add comment * keep up with recent changes in reduction API * format * resolve merge conflicts Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -58,6 +58,33 @@ struct Add
|
||||
}
|
||||
};
|
||||
|
||||
struct SquaredAdd
|
||||
{
|
||||
template <class T>
|
||||
__host__ __device__ static constexpr T GetIdentityValue()
|
||||
{
|
||||
return type_convert<T>(0.0f);
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
return operation == InMemoryDataOperationEnum::AtomicAdd ||
|
||||
operation == InMemoryDataOperationEnum::Set;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ inline constexpr void operator()(T& a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
|
||||
is_same<T, int8_t>::value,
|
||||
"The data type is not supported by the Max accumulator!");
|
||||
|
||||
a = a + b * b;
|
||||
}
|
||||
};
|
||||
|
||||
struct Mul
|
||||
{
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user