[CK-Tile] Refactor base pipeline usage (#3251)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

[ROCm/composable_kernel commit: d184eed823]
This commit is contained in:
Max Podkorytov
2025-12-04 11:45:49 -08:00
committed by GitHub
parent c25f2909d0
commit 1f7aa130eb
37 changed files with 1012 additions and 1836 deletions

View File

@@ -90,24 +90,9 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
ck_tile::index_t K_total = 1;
for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i)
{
K_total *= args.A_dims[i];
}
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
const auto Run = [&]() {
constexpr auto memory_operation =
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
@@ -116,9 +101,7 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
scheduler>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
@@ -166,14 +149,10 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
ave_time = ck_tile::launch_kernel(s, kernel);
return ave_time;
return ck_tile::launch_kernel(s, kernel);
};
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
return ave_time;
return Run();
}
#define HANDLE_CASE(G, M, N, K) \