[CK-Tile] Refactor base pipeline usage (#3251)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder
This commit is contained in:
Max Podkorytov
2025-12-04 11:45:49 -08:00
committed by GitHub
parent d9d4c9c3df
commit d184eed823
37 changed files with 1012 additions and 1836 deletions

View File

@@ -46,14 +46,6 @@ struct SplitKTwoStageInvoker
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -67,41 +59,22 @@ struct SplitKTwoStageInvoker
Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
constexpr auto scheduler = GemmConfig::Scheduler;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -230,7 +203,7 @@ struct SplitKTwoStageInvoker
preprocess = clear_gemm_output;
}
ave_time = ck_tile::launch_kernel_time_mask(
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
@@ -244,21 +217,15 @@ struct SplitKTwoStageInvoker
ck_tile::make_tuple(args.N, 1), // Output Stride
input_tensors,
static_cast<CDataType*>(c_ptr)));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
if(args.k_batch == 1)
{
return Run(MemoryOpSet{});
}
else
{
return Run(MemoryOpAtomicAdd{});
}
}
};

View File

@@ -133,14 +133,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
@@ -154,19 +146,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
// Create base GEMM arguments pointing to workspace instead of final output
// The workspace will store partial results from each K-split
ck_tile::GemmHostArgs base_args(args.a_ptr,
@@ -179,23 +158,18 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
args.stride_A,
args.stride_B,
args.stride_E);
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&]() {
// use SET operation since each K-split writes to separate memory
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -276,29 +250,20 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
return ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
return ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
// For workspace mode, always use SET operation since each K-split writes to separate memory
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
};
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return Run();
}
/**

View File

@@ -33,14 +33,6 @@ struct WeightPreshuffleInvoker
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -54,39 +46,20 @@ struct WeightPreshuffleInvoker
Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
constexpr auto scheduler = GemmConfig::Scheduler;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -139,6 +112,7 @@ struct WeightPreshuffleInvoker
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
<< std::endl;
}
float ave_time = 0.f;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
@@ -183,21 +157,14 @@ struct WeightPreshuffleInvoker
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
throw std::runtime_error("split-k is not supported yet!");
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
throw std::runtime_error("split-k is not supported yet!");
}
}
};

View File

@@ -63,14 +63,17 @@ void permute_tensor_b(Tensor& tensor)
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler,
true,
ck_tile::TailNumber::Full>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ADataType,
true>;
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
UniversalGemmProblem>;

View File

@@ -34,14 +34,6 @@ struct UniversalInvoker
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -55,39 +47,22 @@ struct UniversalInvoker
Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
constexpr auto scheduler = GemmConfig::Scheduler;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -173,25 +148,19 @@ struct UniversalInvoker
preprocess = clear_gemm_output;
}
ave_time = ck_tile::launch_kernel_time_mask(
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
if(args.k_batch == 1)
{
return Run(MemoryOpSet{});
}
else
{
return Run(MemoryOpAtomicAdd{});
}
}
};