[CK_TILE] Support multi-config in tile_example_gemm_universal (#2240)

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Add GemmConfig in run_gemm_example to support multiple tile config.
- It is useful when use you need compare gemm perf with different tile/pipeline config
- we also can use it simplify the code for wmma support in the furture.

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Address review comments

* rebase code and fix clang format.

* fix clang format

* support pipeline v5.

* fix merge conflict

* address review comment

* add missing file

* address review comment v2

* fix build error
This commit is contained in:
linqunAMD
2025-06-18 08:27:46 +08:00
committed by GitHub
parent df54667102
commit 0eb8974502
6 changed files with 306 additions and 155 deletions

View File

@@ -13,7 +13,8 @@
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -23,7 +24,7 @@ template <typename ADataType,
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
{
@@ -45,7 +46,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout>;
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -61,7 +63,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
@@ -75,7 +78,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -87,7 +90,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -108,7 +112,6 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -205,7 +208,10 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
return ave_time;
}
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -215,12 +221,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
@@ -233,22 +239,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
{
if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Col{}, Col{}, Row{});
}
else
@@ -258,6 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -270,31 +277,43 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
@@ -305,7 +324,7 @@ int main(int argc, char* argv[])
{
try
{
return !run_gemm_example(argc, argv);
return !run_gemm_example<GemmConfigComputeV3>(argc, argv);
}
catch(const std::runtime_error& e)
{