mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK-Tile] fixup codegen for tile engine ops gemm multid and gemm preshuffle (#3383)
* fixup gemm multi-d and preshuffle in tile engine codegen --------- Co-authored-by: Thrupti Raj Lakshmana Gowda <thruptiraj.lakshmanagowda@amd.com>
This commit is contained in:
@@ -452,34 +452,23 @@ struct SelectedKernel {{
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC>,
|
||||
scheduler>;
|
||||
|
||||
float ave_time{{0}};
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
const auto Run = [&](const auto memory_operation_) {{
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC>,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
// Epilogue
|
||||
// Epilogue
|
||||
"""
|
||||
|
||||
# Add epilogue configuration based on type
|
||||
@@ -552,29 +541,18 @@ struct SelectedKernel {{
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
return ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
if(args.k_batch == 1) {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
if(args.k_batch == 1) {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user