mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK] Add command option instance_index and param_mask to run partial ck test (#2889)
* [CK] Add command option instance_index and param_mask to run partial ck test Many CK test are instance test. it will loop all instance in the instance library. It causes test often out-of-time if we run test on simulator/emulator. This PR add option instance_index and param_mask to reduce the workload of instance test instance_index: only run test 1 available instance with specified index. param_mask: filter the embedded parameter with specified mask * fix CI error * fix clang format --------- Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -31,4 +31,4 @@ using AccDataType = float;
|
||||
|
||||
#include "run_gemm_test.inc"
|
||||
|
||||
int main() { return run_gemm_test(); }
|
||||
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }
|
||||
|
||||
@@ -31,4 +31,4 @@ using AccDataType = float;
|
||||
|
||||
#include "run_gemm_test.inc"
|
||||
|
||||
int main() { return run_gemm_test(); }
|
||||
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }
|
||||
|
||||
@@ -31,4 +31,4 @@ using AccDataType = float;
|
||||
|
||||
#include "run_gemm_test.inc"
|
||||
|
||||
int main() { return run_gemm_test(); }
|
||||
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }
|
||||
|
||||
@@ -31,4 +31,4 @@ using AccDataType = double;
|
||||
|
||||
#include "run_gemm_test.inc"
|
||||
|
||||
int main() { return run_gemm_test(); }
|
||||
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }
|
||||
|
||||
@@ -31,4 +31,4 @@ using AccDataType = int32_t;
|
||||
|
||||
#include "run_gemm_test.inc"
|
||||
|
||||
int main() { return run_gemm_test(); }
|
||||
int main(int argc, char* argv[]) { return run_gemm_test(argc, argv); }
|
||||
|
||||
@@ -105,6 +105,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
int problem_index = -1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -115,16 +116,28 @@ int main(int argc, char* argv[])
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
problem_index = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: time kernel (0=no, 1=yes)" << std::endl;
|
||||
<< "arg2: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg3: problem index (0-35, -1 means all)" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
for(auto& p : problems)
|
||||
for(size_t i = 0; i < problems.size(); i++)
|
||||
{
|
||||
if(problem_index != -1 && problem_index != static_cast<ck::index_t>(i))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto& p = problems[i];
|
||||
GemmParams& problem_size = std::get<0>(p);
|
||||
const LayoutConfig& layout_config = std::get<1>(p);
|
||||
const auto& factory = std::get<2>(p);
|
||||
|
||||
@@ -261,6 +261,44 @@ struct TestGemm
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
template <template <class...> class DeviceGemmPtr_,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
bool IsSupportedArgument(DeviceGemmPtr_<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>* gemmPtr,
|
||||
const GemmParams& params = GemmParams{})
|
||||
{
|
||||
auto invoker_ptr = gemmPtr->MakeInvokerPointer();
|
||||
auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(nullptr),
|
||||
static_cast<BDataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
params.M,
|
||||
params.N,
|
||||
params.K,
|
||||
params.StrideA,
|
||||
params.StrideB,
|
||||
params.StrideC,
|
||||
AElementwiseOperation{},
|
||||
BElementwiseOperation{},
|
||||
CElementwiseOperation{});
|
||||
|
||||
return gemmPtr->IsSupportedArgument(argument_ptr.get());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gemm_util
|
||||
|
||||
@@ -1,13 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run_gemm_test()
|
||||
int run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
ck::gemm_util::GemmParams params;
|
||||
ck::index_t instance_index = -1;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default params
|
||||
}
|
||||
else if(argc == 4 || argc == 5)
|
||||
{
|
||||
params.M = atoi(argv[1]);
|
||||
params.N = atoi(argv[2]);
|
||||
params.K = atoi(argv[3]);
|
||||
params.StrideA = params.M;
|
||||
params.StrideB = params.N;
|
||||
params.StrideC = params.K;
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
instance_index = atoi(argv[4]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1-4: M N K instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
std::cout << "Params (M, N, K, index) " << params.M << " " << params.N << " " << params.K << " "
|
||||
<< instance_index << std::endl;
|
||||
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
|
||||
bool pass = true;
|
||||
|
||||
@@ -24,10 +50,31 @@ int run_gemm_test()
|
||||
const auto gemmPtrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
ck::index_t num_instance = 0;
|
||||
for(auto& gemmPtr : gemmPtrs)
|
||||
{
|
||||
pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get());
|
||||
if(instance_index == -1)
|
||||
{
|
||||
pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get(), params);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto test_gemm = ck::gemm_util::TestGemm<AccDataType>{};
|
||||
if(test_gemm.IsSupportedArgument(gemmPtr.get(), params))
|
||||
{
|
||||
if(num_instance == instance_index)
|
||||
{
|
||||
pass &= test_gemm(gemmPtr.get(), params);
|
||||
}
|
||||
num_instance++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(instance_index != -1)
|
||||
{
|
||||
std::cout << "TestGemm_instance (" << instance_index << "/" << num_instance
|
||||
<< "): " << (pass ? "Passed" : "Failed") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
Reference in New Issue
Block a user