mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Multiple fixes to GroupedGemm+SplitK (#707)
* Add license header.
* Reduce number of logged output. Add constant initialization.
* Add functional tests for grouped_gemm with different kbatch value.
* Add debug log informations + remove unused code.
* Don't pass kbatch to CalculateKPadded.
* Turn on logging in grouped gemm and gemm splitk profiler
* Debug: limit number of test cases to run;
* Log more information and initialize with constant value.
* Turn on DEBUG_LOG
* Add more debug log informations.
* Limit the number of instances to compile.
* Use GridwiseGemmPipeline
* Use KBatch to calculate K0
* Multiple DebugLog messages.
* Unit tests for multiple KBatch values.
* Refactoring
* Disable logging
* extract out of if statement KBatch update.
* Uncomment instances.
* Disable DebugLog.
* Use Kbatch when calculate KPadded.
* Fix CGridDesc padding.
* Use available helper functions.
* Uncomment code commented for debuggin.
* Remove unnecessary debug log messages.
* Uncomment previously commented code for debug purposes.
* Add KBatch info to profiler output summary log.
* Add gtests for gemm splitk using ckProfiler API.
* Add more test-cases for different data layout.
* Add more test cases for gemm splitk
* Remove old test.
* Unit tests for MKNK ggemm interface.
* Fix and add more unit-tests.
* Constepxr everything!
* Increase error threshold for fp16 and splitk.
Since we're using fp16 atomic add for splitk there's a
known precision loss.
---------
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 70e4eb567f]
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
|
||||
target_link_libraries(test_gemm_split_k PRIVATE utility)
|
||||
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance)
|
||||
add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp)
|
||||
target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance)
|
||||
endif()
|
||||
|
||||
@@ -1,261 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/library/utility/host_gemm.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
KM_KN_MN, // 2
|
||||
KM_NK_MN, // 3
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float max_diff = 1e-6;
|
||||
|
||||
for(std::size_t i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct gemmArgs
|
||||
{
|
||||
GemmMatrixLayout layout;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int StrideA;
|
||||
int StrideB;
|
||||
int StrideC;
|
||||
int KBatch;
|
||||
};
|
||||
|
||||
int test_gemm(const gemmArgs& args)
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
bool a_row_major, b_row_major, c_row_major;
|
||||
|
||||
switch(args.layout)
|
||||
{
|
||||
case GemmMatrixLayout::MK_KN_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::MK_NK_MN:
|
||||
a_row_major = true;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_KN_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = true;
|
||||
c_row_major = true;
|
||||
break;
|
||||
case GemmMatrixLayout::KM_NK_MN:
|
||||
a_row_major = false;
|
||||
b_row_major = false;
|
||||
c_row_major = true;
|
||||
break;
|
||||
default: printf("not supported layout"); return 1;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(row_major)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
|
||||
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
|
||||
Tensor<float> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
Tensor<float> c_m_n_device_result(
|
||||
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
|
||||
|
||||
// init data
|
||||
std::size_t num_thread = 1;
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
|
||||
// set zero to c_device_buf
|
||||
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);
|
||||
|
||||
host_gemm_mk_kn_mn(a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
|
||||
bool success = false;
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
|
||||
decltype(b_layout),
|
||||
decltype(c_layout),
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
const auto gemm_ptrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(auto& gemm_ptr : gemm_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.StrideA,
|
||||
args.StrideB,
|
||||
args.StrideC,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
args.KBatch);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get());
|
||||
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(!check_out(c_m_n_host_result, c_m_n_device_result))
|
||||
{
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success;
|
||||
};
|
||||
|
||||
bool success = false;
|
||||
|
||||
if(args.layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
success = test(Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
success = test(Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
success = test(Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
success = test(Col{}, Col{}, Row{});
|
||||
}
|
||||
|
||||
auto error_code = 0;
|
||||
if(success)
|
||||
{
|
||||
std::cout << "test split k : Pass" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test split k: Fail " << std::endl;
|
||||
error_code = -1; // test needs to report failure
|
||||
}
|
||||
return error_code;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
std::vector<gemmArgs> test_cases;
|
||||
if(argc == 1)
|
||||
{
|
||||
test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2},
|
||||
{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}};
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
const int K = std::stoi(argv[4]);
|
||||
|
||||
const int StrideA = std::stoi(argv[5]);
|
||||
const int StrideB = std::stoi(argv[6]);
|
||||
const int StrideC = std::stoi(argv[7]);
|
||||
const int KBatch = std::stoi(argv[8]);
|
||||
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
|
||||
return -1;
|
||||
}
|
||||
bool error = false;
|
||||
for(const auto& kinder : test_cases)
|
||||
{
|
||||
error |= test_gemm(kinder);
|
||||
}
|
||||
return error ? 1 : 0;
|
||||
}
|
||||
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
66
test/gemm_split_k/test_gemm_splitk.cpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_splitk_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_MK_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_KN
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK_KM_NK
|
||||
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ADataType, BDataType, CDataType
|
||||
std::tuple< F16, F16, F16>,
|
||||
std::tuple< F32, F32, F32>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
|
||||
|
||||
#include "test_gemm_splitk_ut_cases.inc"
|
||||
217
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
Normal file
217
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
Normal file
@@ -0,0 +1,217 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 320;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{127};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 437;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideA = K;
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = N;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmSplitK_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
|
||||
constexpr int StrideB = K;
|
||||
constexpr int StrideC = N;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K, M, StrideB, StrideC);
|
||||
}
|
||||
78
test/gemm_split_k/test_gemm_splitk_util.hpp
Normal file
78
test/gemm_split_k/test_gemm_splitk_util.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "include/ck/utility/data_type.hpp"
|
||||
#include "profiler/profile_gemm_splitk_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmSplitK : public testing::Test
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using CDataType = std::tuple_element_t<4, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
|
||||
|
||||
void Run(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC)
|
||||
{
|
||||
for(auto kb : k_batches_)
|
||||
{
|
||||
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(const int M,
|
||||
const int N,
|
||||
const int K,
|
||||
const int StrideA,
|
||||
const int StrideB,
|
||||
const int StrideC,
|
||||
int kbatch = 1)
|
||||
{
|
||||
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
|
||||
BDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,9 @@
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
|
||||
target_link_libraries(test_grouped_gemm_fp16 PRIVATE utility)
|
||||
target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance)
|
||||
add_custom_target(test_grouped_gemm)
|
||||
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
|
||||
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
|
||||
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
|
||||
endif()
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
bool TestGroupedGemm()
|
||||
{
|
||||
|
||||
std::mt19937 gen(19391);
|
||||
std::uniform_int_distribution<> distrib(1, 10);
|
||||
int group_count = distrib(gen);
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a, p_b;
|
||||
std::vector<void*> p_c;
|
||||
|
||||
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * distrib(gen));
|
||||
Ns.push_back(256 + 256 * distrib(gen));
|
||||
Ks.push_back(128 + 128 * distrib(gen));
|
||||
|
||||
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
|
||||
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
|
||||
StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
|
||||
}
|
||||
|
||||
return ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int main()
|
||||
{
|
||||
bool res = true;
|
||||
|
||||
res = res && TestGroupedGemm<Row, Row, Row>();
|
||||
res = res && TestGroupedGemm<Row, Col, Row>();
|
||||
res = res && TestGroupedGemm<Col, Row, Row>();
|
||||
res = res && TestGroupedGemm<Col, Col, Row>();
|
||||
|
||||
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
|
||||
return res ? 0 : 1;
|
||||
}
|
||||
202
test/grouped_gemm/test_grouped_gemm_interface.cpp
Normal file
202
test/grouped_gemm/test_grouped_gemm_interface.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
using GGemmInstance =
|
||||
ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmSpec,
|
||||
KPerBlock,
|
||||
K1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 8>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 188, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % MPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{256, 128, 128, 512};
|
||||
Ns = std::vector<int>{256, 177, 128, 512};
|
||||
// N % NPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 4, 8, 8>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// K % ABlockTransferSrcScalarPerVector
|
||||
Ks = std::vector<int>{256, 177, 128, 512};
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ks = std::vector<int>{256, 164, 128, 512};
|
||||
// K % BBlockTransferSrcScalarPerVector
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ks = std::vector<int>(4, 128);
|
||||
Ns = std::vector<int>{256, 127, 128, 512};
|
||||
// N % CBlockTransferScalarPerVector_NWaveNPerXDL
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
constexpr int kbatch = 4;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// kloops % 2
|
||||
Ks = std::vector<int>{256, 512, 320, 768};
|
||||
EXPECT_FALSE(
|
||||
DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch));
|
||||
|
||||
// Not all gemms have same value for main_k0_block_loop!
|
||||
Ks = std::vector<int>{256, 512, 512, 512};
|
||||
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
|
||||
std::runtime_error);
|
||||
}
|
||||
|
||||
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using ELayout = Col;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
using GGemmInstance =
|
||||
ck::test::DeviceGroupedGemmSplitkInstanceWrapper<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmSpec,
|
||||
KPerBlock,
|
||||
K1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using DefaultGGemmInstance = GGemmInstance<GemmDefault, 32, 8, 4, 8, 4>;
|
||||
};
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 188, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 128;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % MPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{128, 256, 256, 512};
|
||||
Ns = std::vector<int>{256, 177, 128, 512};
|
||||
// N % NPerBlock
|
||||
EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
|
||||
TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
|
||||
{
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
using PaddedGGemmInstance = GGemmInstance<GemmMNKPadding, 32, 8, 2, 8, 4>;
|
||||
|
||||
std::vector<int> Ms{128, 256, 256, 512};
|
||||
constexpr int N = 256;
|
||||
constexpr int K = 512;
|
||||
|
||||
std::vector<int> Ns(Ms.size(), N);
|
||||
std::vector<int> Ks(Ms.size(), K);
|
||||
std::vector<int> StrideAs(Ms.size(), K);
|
||||
std::vector<int> StrideBs(Ms.size(), K);
|
||||
std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
// M % ABlockTransferSrcScalarPerVector
|
||||
Ms = std::vector<int>{256, 177, 128, 512};
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ms = std::vector<int>{128, 256, 256, 512};
|
||||
Ns = std::vector<int>{256, 164, 128, 512};
|
||||
// N % BBlockTransferSrcScalarPerVector
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
|
||||
Ns = std::vector<int>{128, 256, 256, 512};
|
||||
Ms = std::vector<int>{256, 130, 128, 512};
|
||||
// M % CBlockTransferScalarPerVector_NWaveNPerXDL
|
||||
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
|
||||
}
|
||||
34
test/grouped_gemm/test_grouped_gemm_splitk.cpp
Normal file
34
test/grouped_gemm/test_grouped_gemm_splitk.cpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
|
||||
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
|
||||
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
|
||||
RRR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
|
||||
RCR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
180
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
180
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
@@ -0,0 +1,180 @@
|
||||
#pragma once
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{0, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{0, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{32, 64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
|
||||
{
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 4096;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
|
||||
{
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 4096;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
249
test/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
249
test/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
@@ -0,0 +1,249 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Range>
|
||||
std::string serialize_range(const Range& range)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for(auto& r : range)
|
||||
{
|
||||
ss << r << ", ";
|
||||
}
|
||||
std::string str = ss.str();
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemm : public testing::TestWithParam<int>
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
|
||||
void SetUp() override {}
|
||||
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
struct DeviceGroupedGemmSplitkInstanceWrapper
|
||||
{
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using EmptyTuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using ABlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
|
||||
using ABlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
|
||||
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
|
||||
|
||||
using BBlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
|
||||
using BBlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
|
||||
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
|
||||
|
||||
using DeviceGroupedGemmSplitKInstance =
|
||||
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
EmptyTuple,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
F16,
|
||||
EmptyTuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<1, 4, 32, 1>,
|
||||
ABlockTransferThreadClusterArrageOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim::value,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1::value,
|
||||
ABlockLdsAddExtraM::value,
|
||||
S<1, 4, 32, 1>,
|
||||
BBlockTransferThreadClusterArrageOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim::value,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1::value,
|
||||
BBlockLdsAddExtraM::value,
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 8>,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
bool IsSupported(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(argument, kbatch);
|
||||
}
|
||||
|
||||
return ggemm_instance.IsSupportedArgument(argument);
|
||||
}
|
||||
|
||||
float Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(argument, kbatch);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
|
||||
auto invoker = ggemm_instance.MakeInvoker();
|
||||
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
|
||||
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
|
||||
return invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user