[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: e339101e9c]
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent 077d75cea0
commit 6cf89bbca9
68 changed files with 4198 additions and 4298 deletions

View File

@@ -48,112 +48,87 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfiguration::NUM_WAVE_GROUPS,
GemmConfiguration::PRESHUFFLE>;
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
memory_operation.value,
GemmConfiguration::NUM_WAVE_GROUPS>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
GemmConfiguration::NUM_WAVE_GROUPS>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kernel_args))
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
if(stream_config.log_level_ > 0)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
}
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kernel_args))
auto reset_data_buffers = [&]() {
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
if(stream_config.log_level_ > 0)
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
auto reset_data_buffers = [&]() {
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
};
std::function<void()> preprocess = reset_data_buffers;
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile =
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
{
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the
// same output tile in the C tensor, so we must
// atomic add the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing
// final results to each macro tile in the C
// tensor, so we can do a set.
ck_tile::memory_operation_enum::set>{});
}
std::function<void()> preprocess = reset_data_buffers;
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
}
#include "run_gemm_example.inc"