mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-25 07:14:37 +00:00
[CK-Tile] Refactor base pipeline usage (#3251)
* initial poc
* factor out common parts in operator()
* cv4
* rest of the universal gemm pipelines
* fix test
* remove boilerplate from tile engine
* fix example
* fix example
* format
* fix tests build for gemm
* remove base pipeline codegen from gemm instance builder
* unify v3 logic with the rest of universal gemm pipelines
* fix build for multi abd test
* fix test gemm multi d
* fix build for weight preshuffle
* fix grouped gemm test
* fix grouped gemm multi d test
* fix grouped gemm preshuffle
* fix grouped gemm example except for quant
* fix gemm preshuffle
* fix splitk 2 stage example
* fix batched gemm example
* fix multid example
* fix multiabd example
* fix batched gemm test
* fixup
* fix examples build
* fix grouped gemm test build
* fix smoke builder
[ROCm/composable_kernel commit: d184eed823]
This commit is contained in:
@@ -90,24 +90,9 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
|
||||
ck_tile::index_t K_total = 1;
|
||||
for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i)
|
||||
{
|
||||
K_total *= args.A_dims[i];
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
const auto Run = [&]() {
|
||||
constexpr auto memory_operation =
|
||||
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
|
||||
|
||||
@@ -116,9 +101,7 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
@@ -166,14 +149,10 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s, kernel);
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
return Run();
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(G, M, N, K) \
|
||||
|
||||
Reference in New Issue
Block a user