[CK_TILE] Support multi-config in tile_example_gemm_universal (#2240)

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Add GemmConfig in run_gemm_example to support multiple tile config.
- It is useful when use you need compare gemm perf with different tile/pipeline config
- we also can use it simplify the code for wmma support in the furture.

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Address review comments

* rebase code and fix clang format.

* fix clang format

* support pipeline v5.

* fix merge conflict

* address review comment

* add missing file

* address review comment v2

* fix build error
This commit is contained in:
linqunAMD
2025-06-18 08:27:46 +08:00
committed by GitHub
parent df54667102
commit 0eb8974502
6 changed files with 306 additions and 155 deletions

View File

@@ -12,7 +12,8 @@
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -22,7 +23,7 @@ template <typename ADataType,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
{
@@ -140,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
@@ -156,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
else
{
if(a_layout == "R" && b_layout == "R")
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
@@ -211,15 +212,19 @@ int run_gemm_example(int argc, char* argv[])
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");