Files
composable_kernel/test/gemm_split_k/test_gemm_splitk_util.hpp
Adam Osewski 70e4eb567f Multiple fixes to GroupedGemm+SplitK (#707)
* Add license header.

* Reduce number of logged output. Add constant initialization.

* Add functional tests for grouped_gemm with different kbatch value.

* Add debug log informations + remove unused code.

* Don't pass kbatch to CalculateKPadded.

* Turn on logging in grouped gemm and gemm splitk profiler

* Debug: limit number of test cases to run;

* Log more information and initialize with constant value.

* Turn on DEBUG_LOG

* Add more debug log informations.

* Limit the number of instances to compile.

* Use GridwiseGemmPipeline

* Use KBatch to calculate K0

* Multiple DebugLog messages.

* Unit tests for multiple KBatch values.

* Refactoring

* Disable logging
* extract out of if statement KBatch update.

* Uncomment instances.

* Disable DebugLog.

* Use Kbatch when calculate KPadded.

* Fix CGridDesc padding.

* Use available helper functions.

* Uncomment code commented for debuggin.

* Remove unnecessary debug log messages.

* Uncomment previously commented code for debug purposes.

* Add KBatch info to profiler output summary log.

* Add gtests for gemm splitk using ckProfiler API.

* Add more test-cases for different data layout.

* Add more test cases for gemm splitk

* Remove old test.

* Unit tests for MKNK ggemm interface.

* Fix and add more unit-tests.

* Constepxr everything!

* Increase error threshold for fp16 and splitk.

Since we're using fp16 atomic add for splitk there's a
known precision loss.

---------

Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
2023-05-30 07:09:06 -05:00

79 lines
2.5 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_gemm_splitk_impl.hpp"
namespace ck {
namespace test {
template <typename Tuple>
class TestGemmSplitK : public testing::Test
{
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using CDataType = std::tuple_element_t<4, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
void Run(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC)
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
void RunSingle(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
int kbatch = 1)
{
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
BDataType,
F32,
CDataType,
ALayout,
BLayout,
CLayout>(
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck