[CK Tile] gemm splitk two stage (#2697)

* Fix a typo

* Use std::variant to call run_gemm_example_with_layouts with the available layout variant combinations

* Use a unified run_gemm_example_prec_type for basic gemm and universal gemm

* Factor out run_gemm_example_prec_type

* Refactor argument parsing in gemm_splitk_two_stage_reduce.cpp

* Parse arguments outside of create_args

* Move the gemm operators to separate structs to facilitate their reuse

* Move the invokers to separate files to facilitate their reuse

* Rename the invoker files for consistency with the examples that use them

* Add fp32 support to the elementwise examples, and produce an error message for unsupported types

* Get rid of four unused variables

* Make two variables const

* Add support for different input-output type combinations in elementwise examples

* Test support for different input and output types in elementwise examples

* Add support for different operations in the elementwise unary tests

* Add support for UnaryConvert in the elementwise unary tests

* Add support for bf16 in elementwise examples, excluding unsupported type combinations

* Make some operator parameters const in ElementWiseKernel

* Remove some unnecessary include statements

* Implement a two-stage GEMM that does a type conversion in the second stage using the elementwise kernel

* Clear workspace instead of output when flushing the cache in SplitKTwoStageInvoker::gemm

* Fix formatting issues reported by clang

* Add back CK_TILE_USE_WMMA related changes

* Use the right prec type for bf16 in the universal GEMM and two stage split K examples

* Add some brackets

* Add some brackets

* Separate the clearing of the GEMM output memory from the cache flushing in the universal GEMM example

* Separate the clearing of the GEMM output memory from the cache flushing in the split K two stage example

* Fix formatting

* No need to call SetZero on ws_m_n_dev_buf here, as clear_gemm_output now does this as part of the kernel preprocessing

* Add fp16 data type to splitk two stage example

* Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the basic GEMM example
This commit is contained in:
SamiAario-AMD
2025-09-04 14:33:44 +03:00
committed by GitHub
parent e2d28a92af
commit 1acd8e041c
21 changed files with 1245 additions and 782 deletions

View File

@@ -146,20 +146,7 @@ void permute_vectors_i4x4_b(Tensor& tensor)
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType,
typename DsDataType,
@@ -200,36 +187,36 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time;
if(persistent)
{
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
true,
CDEElementWise>(
ave_time = Invoker::template gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
true,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
}
else
{
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
ave_time = Invoker::template gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
@@ -274,6 +261,7 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
}
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
@@ -399,6 +387,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
c_m_n_dev_result.SetZero();
float ave_time = invoke_gemm<GemmConfig,
Invoker,
ADataType,
BDataType,
ck_tile::tuple<>,