mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[GEMM] F8 GEMM, performance optimized. (#1384)
* add ab_scale init support
* enabled interwave
* add scale type; update isSupport
* adjust example
* clean
* enable f8 pure gemm rcr ckprofiler
* Add gemm_multiply_multiply instances
* clang format
* Optimize for ScaleBlockMNK=128
* enable abscale f8 gemm ck profiler
* Add pure f8 gemm test suite
* Reverting to the state of project at f60fd77
* update copyright
* clang format
* update copyright
---------
Co-authored-by: root <jizhan@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -25,12 +25,13 @@ class TestGemmUniversal : public testing::Test
|
||||
using F32 = float;
|
||||
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using CDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<4, Tuple>;
|
||||
using CDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
@@ -66,6 +67,7 @@ class TestGemmUniversal : public testing::Test
|
||||
{
|
||||
bool pass = ck::profiler::profile_gemm_universal_impl<ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ALayout,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
@@ -41,16 +41,24 @@ class TestGemmUniversal_MK_NK
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ADataType, BDataType, CDataType
|
||||
std::tuple< F16, F16, F16>,
|
||||
std::tuple< F16, F8, F16>,
|
||||
std::tuple< F8, F16, F16>,
|
||||
std::tuple< BF16, BF16, BF16>
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< BF16, BF16, BF16, BF16>
|
||||
>;
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< BF16, BF16, BF16, BF16>,
|
||||
std::tuple< F8, F8, F8, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
|
||||
TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
#include "test_gemm_universal_ut_cases.inc"
|
||||
|
||||
Reference in New Issue
Block a user