[CK] Add command option instance_index and param_mask to run partial ck test (#2889)

* [CK] Add command option instance_index and param_mask to run partial ck test

Many CK test are instance test. it will loop all instance in the instance library. It causes test often out-of-time if we run test on simulator/emulator.
This PR add option instance_index and param_mask to reduce the workload of instance test

instance_index: only run test 1 available instance with specified index.
param_mask: filter the embedded parameter with specified mask

* fix CI error

* fix clang format

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
linqunAMD
2025-09-30 23:24:40 +08:00
committed by GitHub
parent 28ad8ae5d8
commit e78a897ec0
113 changed files with 2804 additions and 704 deletions

View File

@@ -31,4 +31,4 @@ using AccDataType = float;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }

View File

@@ -31,4 +31,4 @@ using AccDataType = float;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }

View File

@@ -31,4 +31,4 @@ using AccDataType = float;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }

View File

@@ -31,4 +31,4 @@ using AccDataType = double;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }

View File

@@ -31,4 +31,4 @@ using AccDataType = int32_t;
#include "run_gemm_test.inc"
int main() { return run_gemm_test(); }
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }

View File

@@ -105,6 +105,7 @@ int main(int argc, char* argv[])
bool do_verification = true;
bool time_kernel = true;
int problem_index = -1;
if(argc == 1)
{
@@ -115,16 +116,28 @@ int main(int argc, char* argv[])
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
problem_index = std::stoi(argv[3]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl;
<< "arg2: time kernel (0=no, 1=yes)" << std::endl
<< "arg3: problem index (0-35, -1 means all)" << std::endl;
return 0;
}
bool pass = true;
for(auto& p : problems)
for(size_t i = 0; i < problems.size(); i++)
{
if(problem_index != -1 && problem_index != static_cast<ck::index_t>(i))
{
continue;
}
auto& p = problems[i];
GemmParams& problem_size = std::get<0>(p);
const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);

View File

@@ -261,6 +261,44 @@ struct TestGemm
return true;
}
}
template <template <class...> class DeviceGemmPtr_,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
bool IsSupportedArgument(DeviceGemmPtr_<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{})
{
auto invoker_ptr = gemmPtr->MakeInvokerPointer();
auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(nullptr),
static_cast<BDataType*>(nullptr),
static_cast<CDataType*>(nullptr),
params.M,
params.N,
params.K,
params.StrideA,
params.StrideB,
params.StrideC,
AElementwiseOperation{},
BElementwiseOperation{},
CElementwiseOperation{});
return gemmPtr->IsSupportedArgument(argument_ptr.get());
}
};
} // namespace gemm_util

View File

@@ -1,13 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run_gemm_test()
int run_gemm_test(int argc, char* argv[])
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
ck::gemm_util::GemmParams params;
ck::index_t instance_index = -1;
if(argc == 1)
{
// use default params
}
else if(argc == 4 || argc == 5)
{
params.M = atoi(argv[1]);
params.N = atoi(argv[2]);
params.K = atoi(argv[3]);
params.StrideA = params.M;
params.StrideB = params.N;
params.StrideC = params.K;
if(argc == 5)
{
instance_index = atoi(argv[4]);
}
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1-4: M N K instance_index(-1 means all)" << std::endl;
}
std::cout << "Params (M, N, K, index) " << params.M << " " << params.N << " " << params.K << " "
<< instance_index << std::endl;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
@@ -24,10 +50,31 @@ int run_gemm_test()
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
ck::index_t num_instance = 0;
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get());
if(instance_index == -1)
{
pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get(), params);
}
else
{
auto test_gemm = ck::gemm_util::TestGemm<AccDataType>{};
if(test_gemm.IsSupportedArgument(gemmPtr.get(), params))
{
if(num_instance == instance_index)
{
pass &= test_gemm(gemmPtr.get(), params);
}
num_instance++;
}
}
}
if(instance_index != -1)
{
std::cout << "TestGemm_instance (" << instance_index << "/" << num_instance
<< "): " << (pass ? "Passed" : "Failed") << std::endl;
}
return pass;