[CK_TILE] Support multi-config in tile_example_gemm_universal (#2240)

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Add GemmConfig in run_gemm_example to support multiple tile config.
- It is useful when use you need compare gemm perf with different tile/pipeline config
- we also can use it simplify the code for wmma support in the furture.

* [CK_TILE] Support multi-config in tile_example_gemm_universal

Address review comments

* rebase code and fix clang format.

* fix clang format

* support pipeline v5.

* fix merge conflict

* address review comment

* add missing file

* address review comment v2

* fix build error
This commit is contained in:
linqunAMD
2025-06-18 08:27:46 +08:00
committed by GitHub
parent df54667102
commit 0eb8974502
6 changed files with 306 additions and 155 deletions

View File

@@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename Tensor,
template <typename GemmConfig,
typename Tensor,
typename ADataType,
typename BDataType,
typename AccDataType,
@@ -63,11 +64,12 @@ void permute_tensor_b(Tensor& tensor)
AccDataType,
GemmShape,
GemmUniversalTraits,
GEMM_PIPELINE_SCHEDULER,
GemmConfig::Scheduler,
true,
ck_tile::TailNumber::Full>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
@@ -144,7 +146,22 @@ void permute_vectors_i4x4_b(Tensor& tensor)
}
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -184,7 +201,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time;
if(persistent)
{
ave_time = gemm<ADataType,
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -199,7 +217,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
}
else
{
ave_time = gemm<ADataType,
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -232,7 +251,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
@@ -312,7 +332,8 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PermuteB)
{
permute_tensor_b<decltype(b_k_n_dev),
permute_tensor_b<GemmConfig,
decltype(b_k_n_dev),
ADataType,
BDataType,
AccDataType,
@@ -338,7 +359,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,