[CK_tile] Add rotating buffer feature for universal gemm (#2200)

* Add rotating buffer feature for universal gemm

* adding changes in tile_engine

* Updated code to merge kernel_launch

* removing comments

* Enable rotating buffer changes to flatmm

* Created diff launch_kernel function for rotating buffer

* Simplfied calculation using macros

* merge code with new changes in tile_engine

* clang formatted

* Redefine macros
This commit is contained in:
Khushbu Agarwal
2025-05-27 23:00:58 -07:00
committed by GitHub
parent c52649ad57
commit 99857e10e6
17 changed files with 409 additions and 74 deletions

View File

@@ -273,9 +273,52 @@ struct GemmKernel {{
<< std::endl;
}}
ave_time = ck_tile::launch_kernel(stream,
if(stream.flush_cache_)
{{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
auto is_row_major = [](auto layout_) {{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{{}};
}};
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{{}})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{{}})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {{
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}};
ave_time = ck_tile::launch_kernel_preprocess(
stream,
run_flush_cache,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
}}
else{{
ave_time = ck_tile::launch_kernel(stream,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
}}
return ave_time;
}};