mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-27 08:25:46 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc
* factor out common parts in operator()
* cv4
* rest of the universal gemm pipelines
* fix test
* remove boilerplate from tile engine
* fix example
* fix example
* format
* fix tests build for gemm
* remove base pipeline codegen from gemm instance builder
* unify v3 logic with the rest of universal gemm pipelines
* fix build for multi abd test
* fix test gemm multi d
* fix build for weight preshuffle
* fix grouped gemm test
* fix grouped gemm multi d test
* fix grouped gemm preshuffle
* fix grouped gemm example except for quant
* fix gemm preshuffle
* fix splitk 2 stage example
* fix batched gemm example
* fix multid example
* fix multiabd example
* fix batched gemm test
* fixup
* fix examples build
* fix grouped gemm test build
* fix smoke builder
* hacky poc
* fix tile engine
* kill the lambda
* maybe fix test build
* more fixes
* clang-format
* save temp
* clang-format
* mostly fix examples
* clang-format
* remove dead code
* more cleanup
* fix fmha bwd build (default epilogue set/add appears to be broken)
* fix default epilogue tests but not correctness
* clang-format
* fix bquant
* clang-format
* cleanup dead code
* rearrange make windows for readability
* restore changes to IsSupportedArgument
* fix smoke-builder
* clang-format
* fixup rename class
* build fixes
* clang-format
* fix builder
* fixup
* remove set from builder tests
* fix test
* clang-format
* re-refactor the kernels
* clang-format
* fix header license
* remove memory operation from conv bwd test
* clang-format
* clang-format example,include
* clang-format test
* build fixes
* clang-format
* solve compilation error
* fix the CI
* solve compilation error
* clang-format
* solve merge conflict
* solve merge conflict
* solve the gfx11 error
* solve test error
* more build fixes
* remove the AtomicAddRequiresKBatchGreaterThanOne test, since that property has been removed from the kernel scope
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
[ROCm/composable_kernel commit: e339101e9c]
This commit is contained in:
@@ -48,112 +48,87 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is similar to the pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is similar to the pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation.value,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
}
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
};
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile =
|
||||
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the
|
||||
// same output tile in the C tensor, so we must
|
||||
// atomic add the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing
|
||||
// final results to each macro tile in the C
|
||||
// tensor, so we can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
Reference in New Issue
Block a user