CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)

* Block universal gemm.

* Universal block gemm with interwave scheduler - draft.

* Refactoring

* Move a/b_warp_tiles into BlockGemmImpl
* set BlockGemmImpl as a class member

* Change tile size for more suitable to memory bound cases.

* Introduce kKPerThread to WarpGemm

* Add documentation comment.

* Fix Interwave scheduler block gemm.

* Add compute/memory friendly tile configuration.

* Clean

* New tile configurations in gemm mem example.

* Add more static checks and fix loop order in block gemm.

* Add more static checks and use warp gemm mfma dispatcher.

* Add default scheduler block gemm.

* Remove logging in example.
This commit is contained in:
Adam Osewski
2024-11-26 08:45:14 +01:00
committed by GitHub
parent 440e28b08f
commit b6bcd76d88
11 changed files with 779 additions and 56 deletions

View File

@@ -31,15 +31,13 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Gemm{MemBoundPipeline}"};
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
@@ -114,7 +112,6 @@ int run_gemm_example_with_layouts(int argc,
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
@@ -202,14 +199,15 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work. else if(a_layout == "C" && b_layout == "C")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");