[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -719,8 +719,8 @@ struct SelectedKernel {{
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
@@ -802,8 +802,8 @@ struct SelectedKernel {{
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_

View File

@@ -481,8 +481,6 @@ struct SelectedKernel {{
GemmUniversalTraits>;
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
const auto Run = [&](const auto memory_operation_) {{
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -512,7 +510,6 @@ struct SelectedKernel {{
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation, // MemoryOperation_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
@@ -558,30 +555,12 @@ struct SelectedKernel {{
workspace_data.SetZero();
}}
}};
// Launch kernel
float ave_time = ck_tile::launch_kernel_time_mask(
return ck_tile::launch_kernel_time_mask(
stream,
reset_data_buffers,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
// ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
// return std::make_tuple(ave_time, num_wgs_per_tile);
}};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy)
{{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}}
}}
}};
"""