Lwpck 3548 gemm test cleanups (#2717)

* Remove some unnecessary calls to create_args in basic and universal GEMM tests

* Remove unnecessary include statements in universal GEMM tests

* Improve compilation time of basic GEMM tests by only compiling the precision variants that we need

* Universal GEMM PrecType should be the same as CDataType

* Improve compilation time of universal GEMM tests by only compiling the precision variants that we need

* Revert to constexpr when defining some constants

[ROCm/composable_kernel commit: 5e85c38d7d]
This commit is contained in:
SamiAario-AMD
2025-08-26 13:25:48 +03:00
committed by GitHub
parent 92037686ae
commit bf69dd72c2
11 changed files with 108 additions and 218 deletions

View File

@@ -2,4 +2,4 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations("bf16"); }
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }

View File

@@ -2,4 +2,4 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations("bf8"); }
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }

View File

@@ -2,4 +2,4 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations("fp16"); }
int main() { return run_gemm_combinations<ck_tile::half_t>(); }

View File

@@ -2,4 +2,4 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations("bf8"); }
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }

View File

@@ -131,7 +131,9 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
}
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
bool run_gemm_test_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -141,12 +143,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -159,22 +161,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
arg_parser, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Row{}, Row{});
arg_parser, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -183,60 +185,26 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
}
}
template <typename APrecType, typename BPrecType, typename CPrecType>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
{
return run_gemm_test_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_test_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_test_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_test_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_test_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
return run_gemm_test_prec_type<APrecType, BPrecType, CPrecType>(a_layout, b_layout, arg_parser);
}
int run_gemm_combinations(std::string const& data_type)
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_combinations()
{
// Define possible values for each parameter
std::vector<std::string> m_values = {"128", "1024"};
std::vector<std::string> n_values = {"128", "2048"};
std::vector<std::string> k_values = {"64", "128"};
std::vector<std::string> prec_values = {data_type};
std::vector<std::string> m_values = {"128", "1024"};
std::vector<std::string> n_values = {"128", "2048"};
std::vector<std::string> k_values = {"64", "128"};
// We'll store all our arguments as strings first
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_basic",
@@ -246,13 +214,12 @@ int run_gemm_combinations(std::string const& data_type)
"-stride_a=0",
"-stride_b=0",
"-stride_c=0",
"", // prec placeholder
"-v=2",
"-warmup=0",
"-repeat=1"};
// Create an array of const char pointers for argv
constexpr size_t ARG_COUNT = 11;
constexpr size_t ARG_COUNT = 10;
constexpr size_t ARG_MAX_LEN = 64;
char args[ARG_COUNT][ARG_MAX_LEN];
char* argv[ARG_COUNT];
@@ -271,39 +238,35 @@ int run_gemm_combinations(std::string const& data_type)
{
arg_strings[3] = "-k=" + k;
for(const auto& prec : prec_values)
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
arg_strings[7] = "-prec=" + prec;
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
// Call the function with the current configuration
try
{
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
// Call the function with the current configuration
try
{
is_success = run_gemm_test<APrecType, BPrecType, CPrecType>(ARG_COUNT, argv) &&
is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
}
}

View File

@@ -256,16 +256,11 @@ template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
bool run_gemm_test_with_layouts(int argc,
char* argv[],
bool run_gemm_test_with_layouts(ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");

View File

@@ -1,16 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations("bf16"); }
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }

View File

@@ -1,16 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations("bf8"); }
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }

View File

@@ -1,16 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations("fp16"); }
int main() { return run_gemm_combinations<ck_tile::half_t>(); }

View File

@@ -1,16 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations("fp8"); }
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }

View File

@@ -200,7 +200,9 @@ template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
bool run_gemm_test_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -210,12 +212,12 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -228,22 +230,22 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
arg_parser, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Row{}, Row{});
arg_parser, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -252,69 +254,27 @@ bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int arg
}
}
template <template <typename PreType> typename GemmConfig>
template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
{
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_test_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
return run_gemm_test_prec_type<GemmConfig, APrecType, BPrecType, CPrecType>(
a_layout, b_layout, arg_parser);
}
int run_gemm_combinations(std::string const& data_type)
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_combinations()
{
// Define possible values for each parameter
std::vector<std::string> m_values = {"512", "1024"};
std::vector<std::string> n_values = {"512", "2048"};
std::vector<std::string> k_values = {"512", "1024"};
std::vector<std::string> prec_values = {data_type};
std::vector<std::string> m_values = {"512", "1024"};
std::vector<std::string> n_values = {"512", "2048"};
std::vector<std::string> k_values = {"512", "1024"};
// We'll store all our arguments as strings first
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_universal",
@@ -324,13 +284,12 @@ int run_gemm_combinations(std::string const& data_type)
"-stride_a=0",
"-stride_b=0",
"-stride_c=0",
"", // prec placeholder
"-v=2",
"-warmup=0",
"-repeat=1"};
// Create an array of const char pointers for argv
constexpr size_t ARG_COUNT = 11;
constexpr size_t ARG_COUNT = 10;
constexpr size_t ARG_MAX_LEN = 64;
char args[ARG_COUNT][ARG_MAX_LEN];
char* argv[ARG_COUNT];
@@ -349,42 +308,43 @@ int run_gemm_combinations(std::string const& data_type)
{
arg_strings[3] = "-k=" + k;
for(const auto& prec : prec_values)
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
arg_strings[7] = "-prec=" + prec;
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
// Call the function with the current configuration
try
{
is_success =
run_gemm_test<GemmConfigComputeV3>(ARG_COUNT, argv) && is_success;
is_success =
run_gemm_test<GemmConfigComputeV3_2>(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
// Call the function with the current configuration
try
{
is_success = run_gemm_test<GemmConfigComputeV3<CPrecType>,
APrecType,
BPrecType,
CPrecType>(ARG_COUNT, argv) &&
is_success;
is_success = run_gemm_test<GemmConfigComputeV3_2<CPrecType>,
APrecType,
BPrecType,
CPrecType>(ARG_COUNT, argv) &&
is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
}
}