[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang-format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* more build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -222,19 +222,13 @@ struct StreamKKernel
const index_t block_idx_n,
const index_t k_size)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Create block windows using specialized methods
const auto& as_block_window =
UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
const auto& bs_block_window =
UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
const auto& ds_block_window =
UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
@@ -243,6 +237,7 @@ struct StreamKKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop,
@@ -253,7 +248,9 @@ struct StreamKKernel
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window =
UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
@@ -525,21 +522,13 @@ struct StreamKKernel
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<
EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
const auto& gemm_pad_views =
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Create block windows using specialized methods
const auto& as_block_window =
UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
const auto& bs_block_window =
UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
const auto& ds_block_window =
UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
@@ -548,6 +537,7 @@ struct StreamKKernel
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop_sk,
@@ -594,7 +584,8 @@ struct StreamKKernel
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
@@ -617,7 +608,8 @@ struct StreamKKernel
// tensor.
if(tile_started && !partner_in_tile)
{
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;

View File

@@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic)
? memory_operation_enum::atomic_add
: memory_operation_enum::set;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);