[GEMM] F8 GEMM, performance optimized. (#1384)

* add ab_scale init support

* enabled interwave

* add scale type; update isSupport

* adjust example

* clean

* enable f8 pure gemm rcr ckprofiler

* Add gemm_multiply_multiply instances

* clang format

* Optimize for ScaleBlockMNK=128

* enable abscale f8 gemm ck profiler

* Add pure f8 gemm test suite

* Reverting to the state of project at f60fd77

* update copyright

* clang format

* update copyright

---------

Co-authored-by: root <jizhan@amd.com>
This commit is contained in:
Haocong WANG
2024-07-19 22:06:52 +08:00
committed by GitHub
parent c544eb4da0
commit 8c90f25be3
59 changed files with 7106 additions and 234 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -249,6 +249,31 @@ struct MultiplyAdd
}
};
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>