Merge commit '4011dbfec31a711aaa4c1071c31bdc55f9b7974a' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-11 23:13:32 +00:00
parent f7eba31069
commit b964d752f0
2 changed files with 44 additions and 88 deletions

View File

@@ -452,34 +452,23 @@ struct SelectedKernel {{
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>,
scheduler>;
float ave_time{{0}};
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
// Epilogue
// Epilogue
"""
# Add epilogue configuration based on type
@@ -552,29 +541,18 @@ struct SelectedKernel {{
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""

View File

@@ -484,35 +484,24 @@ struct SelectedKernel {{
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {{
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC,
UseStructuredSparsity, UsePersistentKernel,
NumWaveGroups, Preshuffle>,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
// Epilogue
"""
@@ -590,29 +579,18 @@ struct SelectedKernel {{
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
if(args.k_batch == 1) {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}}
}};
"""