[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -148,10 +148,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
@@ -217,8 +216,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
@@ -287,99 +285,92 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType,
AQLayout,
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GroupedGemKernelParam::kPadM,
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType,
AQLayout,
BQLayout,
TransposeC,
DoubleSmemBuffer,
Persistent>;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler>>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped ||
QuantType == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
TransposeC,
BDataType,
scheduler>>;
using GemmPipeline = std::conditional_t<
UseGroupedQuant,
std::conditional_t<
QuantType == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GroupedGemKernelParam::M_Warp,
GroupedGemKernelParam::N_Warp,
GroupedGemKernelParam::M_Warp_Tile,
GroupedGemKernelParam::N_Warp_Tile,
GroupedGemKernelParam::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::ignore =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GroupedGemKernelParam::kBlockPerCu>(
Kernel{},
@@ -388,10 +379,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
template <typename Layout>