mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Multiple fixes to GroupedGemm+SplitK (#707)
* Add license header.
* Reduce number of logged output. Add constant initialization.
* Add functional tests for grouped_gemm with different kbatch value.
* Add debug log informations + remove unused code.
* Don't pass kbatch to CalculateKPadded.
* Turn on logging in grouped gemm and gemm splitk profiler
* Debug: limit number of test cases to run;
* Log more information and initialize with constant value.
* Turn on DEBUG_LOG
* Add more debug log informations.
* Limit the number of instances to compile.
* Use GridwiseGemmPipeline
* Use KBatch to calculate K0
* Multiple DebugLog messages.
* Unit tests for multiple KBatch values.
* Refactoring
* Disable logging
* extract out of if statement KBatch update.
* Uncomment instances.
* Disable DebugLog.
* Use Kbatch when calculate KPadded.
* Fix CGridDesc padding.
* Use available helper functions.
* Uncomment code commented for debuggin.
* Remove unnecessary debug log messages.
* Uncomment previously commented code for debug purposes.
* Add KBatch info to profiler output summary log.
* Add gtests for gemm splitk using ckProfiler API.
* Add more test-cases for different data layout.
* Add more test cases for gemm splitk
* Remove old test.
* Unit tests for MKNK ggemm interface.
* Fix and add more unit-tests.
* Constepxr everything!
* Increase error threshold for fp16 and splitk.
Since we're using fp16 atomic add for splitk there's a
known precision loss.
---------
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 70e4eb567f]
This commit is contained in:
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_splitk_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ADataType, BDataType, CDataType
|
||||
std::tuple< F16, F16, F16>,
|
||||
std::tuple< F32, F32, F32>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
|
||||
|
||||
#include "test_gemm_splitk_ut_cases.inc"
|
||||
Reference in New Issue
Block a user