[CK-Tile] fixup codegen for tile engine ops gemm multid and gemm preshuffle (#3383)

* fixup gemm multi-d and preshuffle in tile engine codegen

---------

Co-authored-by: Thrupti Raj Lakshmana Gowda <thruptiraj.lakshmanagowda@amd.com>
This commit is contained in:
Max Podkorytov
2025-12-11 14:23:43 -08:00
committed by GitHub
parent ff194a4271
commit 4011dbfec3
2 changed files with 44 additions and 88 deletions

View File

@@ -452,34 +452,23 @@ struct SelectedKernel {{
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>,
scheduler>;
float ave_time{{0}};
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
// Epilogue
// Epilogue
"""
# Add epilogue configuration based on type
@@ -552,29 +541,18 @@ struct SelectedKernel {{
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""

View File

@@ -484,35 +484,24 @@ struct SelectedKernel {{
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
// Epilogue
"""
@@ -590,29 +579,18 @@ struct SelectedKernel {{
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""