[CK-Tile] Refactor base pipeline usage (#3251)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder
This commit is contained in:
Max Podkorytov
2025-12-04 11:45:49 -08:00
committed by GitHub
parent d9d4c9c3df
commit d184eed823
37 changed files with 1012 additions and 1836 deletions

View File

@@ -123,8 +123,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// for testing purposes, we can hardcode the values here as we what is compatible with
// pipeline
using GemmUniversalTraits =
@@ -140,58 +138,37 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
/*Persistent*/ false,
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
@@ -204,7 +181,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
@@ -213,25 +190,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
if(gemm_descs[0].k_batch == 1)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
}
private:
@@ -247,8 +217,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// Enable persistent mode for preshuffle
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits</*kPadM*/ true,
@@ -263,58 +231,36 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
/*Persistent*/ true, // Enable persistent mode
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
@@ -327,7 +273,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
return ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
@@ -336,25 +282,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
if(gemm_descs[0].k_batch == 1)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
}
public: