diff --git a/CHANGELOG.md b/CHANGELOG.md index af8d965b30..368d1e502d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for Multiple D GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index de9608bcb4..defeffc2ee 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -14,13 +14,17 @@ template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + bool Persistent, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { if constexpr(Persistent) std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." 
<< std::endl; @@ -53,8 +57,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; const auto Run = [&](const auto memory_operation_) { diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f3d11c751b..6987a2492e 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -252,10 +252,13 @@ auto create_args(int argc, char* argv[]) // host API template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + bool Persistent = false, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index bf455a6415..cc9a825c73 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -146,11 +146,14 @@ void permute_vectors_i4x4_b(Tensor& tensor) template + typename DsLayout, + typename CLayout, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -165,41 +168,48 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_repeat, bool persistent) { - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + {}, + c_m_n_dev_buf.GetDeviceBuffer(), + 
kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; float ave_time; if(persistent) { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } else { - ave_time = gemm_calc( + ave_time = gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); } @@ -328,20 +338,27 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm( - a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat, - persistent); + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index fafe40c333..beb6987605 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -15,13 +15,17 @@ template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + typename DsLayout, + typename ELayout, + bool Persistent, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -30,24 +34,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& sequence, GemmConfig::PermuteA, GemmConfig::PermuteB>; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; + ELayout>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits +template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, 
const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) @@ -123,12 +132,16 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre tail_number_v>; using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 0999c7ad3b..78d915e873 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 16a31e519a..7d5e1910dd 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -23,7 +23,16 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -44,20 +53,29 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::BatchedGemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; args.stride_A = stride_A; args.stride_B = stride_B; - 
args.stride_C = stride_C; + args.stride_E = stride_C; args.batch_stride_A = batch_stride_A; args.batch_stride_B = batch_stride_B; - args.batch_stride_C = batch_stride_C; + args.batch_stride_E = batch_stride_C; args.batch_count = batch_count; - float ave_time = batched_gemm( + float ave_time = batched_gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::string op_name{"Batched Gemm"}; @@ -169,22 +187,30 @@ int run_batched_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_batched_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_count, - kbatch, - n_warmup, - n_repeat); + invoke_batched_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index d1a0458eda..59396a558b 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -1,6 +1,6 @@ # Grouped CShuffle GEMM -This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile. +This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. 
## build ``` diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 2a72c6325e..85d75320c5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,7 +16,16 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -template +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -130,9 +139,12 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; auto create_args(int argc, char* argv[]) { @@ -82,7 +83,17 @@ inline std::size_t get_workspace_size(const std::vector& gem return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } -template +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index a01d8178cc..5ed1219731 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -30,7 +30,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(int n_warmup, int n_repeat, int group_count, @@ -44,7 +54,16 @@ float invoke_gemm(int n_warmup, if constexpr(!Persistent) { // Regular version of grouped gemm - ave_time = grouped_gemm( + ave_time = grouped_gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, gemm_workspace.GetDeviceBuffer()); @@ -64,16 +83,18 @@ float invoke_gemm(int n_warmup, const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.c_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, 
- arg.stride_B, - arg.stride_C, - arg.k_batch}); + kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, + arg.b_ptr, + {}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + {}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, @@ -219,10 +240,19 @@ int run_grouped_gemm_example_with_layouts(int argc, void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); } - invoke_gemm(warmup, repeat, group_count, gemm_descs); + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + Persistent>(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) { diff --git a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..e2e68b325a --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp) diff --git a/example/ck_tile/19_gemm_multi_d/README.md b/example/ck_tile/19_gemm_multi_d/README.md new file mode 100644 index 0000000000..7e8cd87546 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/README.md @@ -0,0 +1,35 @@ +#Multiple D GEMM + +This folder contains example for Multiple D GEMM using ck_tile tile-programming implementation. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or +# leave it blank +sh ../script/cmake-ck-dev.sh ../ <arch> +# build the Multiple D GEMM example +make tile_example_gemm_multi_d_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16` + +## example +``` +args: + -m M dimension - (Default: 3840) + -n N dimension - (Default: 4096) + -k K dimension - (Default: 4096) +-a_layout Tensor A layout (Default: R) +-b_layout Tensor B layout (Default: C) +-ds_layout Tensor Ds layout (Default: R) +-e_layout Tensor E layout (Default: R) +-stride_a Tensor A stride - (Default: 0) +-stride_b Tensor B stride - (Default: 0) +-stride_e Tensor E stride - (Default: 0) +-stride_ds Tensor Ds stride - (Default: 0) +-v 0. No validation, 1. Validation against the host (CPU) reference. (Default: 1) + -warmup Number of iterations before benchmarking the kernel. (Default: 50) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. (Default: 1) +``` diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp new file mode 100644 index 0000000000..6c5ca08426 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_d_fp16.hpp" +#include "utils.hpp" + +template +auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float +{ +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr 
ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! 
Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + if(has_hot_loop) + { +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + if(tail_num == ck_tile::TailNumber::One) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + auto check_tail = [&](auto... 
TNs) { + (try_run(tail_num), ...); + }; + + check_tail(ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); + +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + RunSplitk( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else + { + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} + +#include "run_gemm_multi_d_fp16_example.inc" + +int main(int argc, char* argv[]) { return !run_multiple_d_gemm_example(argc, argv); } diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp new file mode 100644 index 0000000000..3ce3965e56 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; +using EDataType = ck_tile::half_t; +using DsDataType = ck_tile::tuple; +using AccDataType = float; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_a", 
"0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") + .insert("v", "1", "0. No validation, 1. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +using gemm_multi_d_kargs = ck_tile::GemmHostArgs; + +template +float gemm_multi_d(const gemm_multi_d_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc new file mode 100644 index 0000000000..a0d7157d03 --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include + +template +float invoke_gemm_multi_d(const void* a_m_k_dev_buf, + const void* b_k_n_dev_buf, + const std::array& ds_m_n_dev_buf, + void* e_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t StrideA, + ck_tile::index_t StrideB, + const std::array& StrideDs, + ck_tile::index_t StrideE, + int n_warmup, + int n_repeat, + int k_batch) +{ + gemm_multi_d_kargs gemm_descs({a_m_k_dev_buf, + b_k_n_dev_buf, + ds_m_n_dev_buf, + e_m_n_dev_buf, + k_batch, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE}); + + float ave_time = gemm_multi_d( + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm Multiple-D"}; + static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * M * N; + flop += sizeof(ck_tile::remove_cvref_t>) * M * N; + }); + + num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm Multiple-D kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE + << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; + + return ave_time; +} + +template +int run_multiple_d_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + 
} + using CDElementWiseFn = MultiplyMultiply; + using DsLayout = ck_tile::tuple; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t StrideA = arg_parser.get_int("stride_a"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_b"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); + + ck_tile::index_t StrideD0 = StrideD; + ck_tile::index_t StrideD1 = StrideD; + + const int n_warmup = arg_parser.get_int("warmup"); + const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); + + StrideA = get_default_stride(M, K, StrideA, is_row_major(a_layout)); + StrideB = get_default_stride(K, N, StrideB, is_row_major(b_layout)); + StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout)); + StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout)); + StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout)); + + ck_tile::HostTensor a_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n_tensors( + host_tensor_descriptor(K, N, StrideB, is_row_major(b_layout))); + ck_tile::HostTensor d0_m_n_tensors( + host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout))); + ck_tile::HostTensor d1_m_n_tensors( + host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout))); + ck_tile::HostTensor e_m_n_device_result( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tesnor); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data()); + b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data()); + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array stridesDs = {StrideD0, StrideD1}; + + invoke_gemm_multi_d(a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + + ck_tile::HostTensor e_m_n_host_ref( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref); + + bool pass{true}; + if(arg_parser.get_int("v")) + { + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + pass &= ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << std::endl; + std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is: " << (pass ? 
"correct" : "fail") << std::endl; + } + return pass; +} + +int run_multiple_d_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string ds_layout = arg_parser.get_str("ds_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_multiple_d_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} diff --git a/example/ck_tile/19_gemm_multi_d/utils.hpp b/example/ck_tile/19_gemm_multi_d/utils.hpp new file mode 100644 index 0000000000..a201d11ffc --- /dev/null +++ b/example/ck_tile/19_gemm_multi_d/utils.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d479cd35f6..f2f39b6e17 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -18,5 +18,6 @@ add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) +add_subdirectory(19_gemm_multi_d) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 79018b9ced..d2b24ad54e 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -59,6 +59,38 @@ 
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, return out_dstr_tensor; } +/** + * @brief Template function that "unpacks" a tuple and applies an element-wise operation. + * + * @param in_element_func Function to apply element-wise. + * @param t Any container containing elements to process, with known size and + * tuple-like semantic. + * @return Calls tile_elementwise_inout with unpacked tuple elements. + */ +template +CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func, + const Tuple& t, + std::index_sequence) +{ + return tile_elementwise_inout(in_element_func, t[number{}]...); +} + +/** + * @brief Template function that "unpacks" a tuple and applies an element-wise operation. + * + * @param in_element_func Function to apply element-wise. + * @param t Any container containing elements to process, with known size and + * tuple-like semantic. + * @return Calls the overloaded function, passing an index sequence. + */ +template +CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func, + const Tuple& t) +{ + static constexpr auto size = Tuple::size(); + return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence{}); +} + template CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value) { diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index fe5077083c..c88deaec01 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -71,6 +71,58 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template >> +CK_TILE_HOST void +reference_gemm_multiple_d(const HostTensor& a_m_k, + const HostTensor& b_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& c_m_n, + const ACCElementOp& acc_element_op = {}) +{ + const 
std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + if constexpr(DsDataType::size() == 0) + { + acc_element_op(v_c, ck_tile::type_convert(v_acc)); + } + else if constexpr(DsDataType::size() == 1) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n))); + } + else if constexpr(DsDataType::size() == 2) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n)), + ck_tile::type_convert(ds_m_n[1](m, n))); + } + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); +} + template CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); } }; #endif + } // namespace element_wise } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 6613ceebb2..68e91520bf 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -11,9 +11,12 @@ namespace ck_tile { template ; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using CLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr index_t kBlockSize = kBlockSize_; static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; @@ -43,6 +49,10 @@ struct CShuffleEpilogueProblem static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum 
MemoryOperation = MemoryOperation_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); }; template @@ -53,10 +63,13 @@ struct CShuffleEpilogue using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using CLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; @@ -69,7 +82,10 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr index_t MPerIteration = MPerXdl * MWave; static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); /** * @brief Get the vector store size for C tensor. 
* @@ -83,22 +99,49 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() { constexpr index_t max_vector_size = 16; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return std::min(static_cast(NPerIteration), static_cast(max_vector_size / sizeof(ODataType))); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return std::min(static_cast(MPerIteration), static_cast(max_vector_size / sizeof(ODataType))); } else { - static_assert(false, "Unsupported CLayout!"); + static_assert(false, "Unsupported ELayout!"); } } + /** + * @brief Get the vector store size for Di tensor. + * + * @return The vector store size for Di tensor. + */ + template + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) + { + constexpr index_t max_vector_size = 16; + using DiDataType = remove_cvref_t>; + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else + { + static_assert(false, "Unsupported DLayout!"); + } + return max_vector_size / sizeof(DiDataType); + } /** * @brief Shuffle tile configuration parameters * @@ -116,7 +159,7 @@ struct CShuffleEpilogue else { constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && (kMPerBlock % num_xdl_shuffles == 0), @@ -147,7 +190,8 @@ struct CShuffleEpilogue }(); static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - using WG = WarpGemmMfmaDispatcher) + if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( make_tuple(number{}, number{}), 
make_tuple(number{}, number<1>{})); } // M is contiguous dimension - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( make_tuple(number{}, number{}), @@ -177,7 +221,7 @@ struct CShuffleEpilogue } else { - static_assert(false, "Unsupported CLayout!"); + static_assert(false, "Unsupported ELayout!"); } } @@ -202,9 +246,11 @@ struct CShuffleEpilogue return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); } - template - CK_TILE_DEVICE auto - operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) + template + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* p_smem) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); @@ -230,7 +276,7 @@ struct CShuffleEpilogue sequence>; constexpr index_t num_access = SFC::get_num_of_access(); - static_assert(std::is_same_v, + static_assert(std::is_same_v, "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); using TileEncodingPattern = @@ -242,6 +288,12 @@ struct CShuffleEpilogue Problem::kNumWaveGroups>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; @@ -265,8 +317,17 @@ struct CShuffleEpilogue store_tile(in_lds_window, c_warptile_in_tensor_casted); block_sync_lds(); - const auto c_out_tensor = - load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + + const auto ds_tensor = generate_tuple( + [&](auto idx) { return 
load_tile(d_dram_windows[idx]); }, number{}); + + const auto c_ds_tiles = concat_tuple_of_reference( + tie(c_out_tensor, c_out_tensor), + generate_tie( + [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); if constexpr(MemoryOperation == memory_operation_enum::set) { @@ -279,7 +340,13 @@ struct CShuffleEpilogue if constexpr(iAccess != num_access - 1) { constexpr auto step = SFC::get_forward_step(iAccess); + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); } }); } diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index d495c0d950..09c7d58558 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -9,7 +9,7 @@ namespace ck_tile { -struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs +struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs { CK_TILE_HOST BatchedGemmHostArgs() = default; CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, @@ -26,18 +26,28 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs ck_tile::index_t batch_stride_B_, ck_tile::index_t batch_stride_C_, ck_tile::index_t batch_count_) - : GemmHostArgs( - a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_), + : GemmHostArgs(a_ptr_, + b_ptr_, + {}, + c_ptr_, + k_batch_, + M_, + N_, + K_, + stride_A_, + stride_B_, + {}, + stride_C_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), - batch_stride_C(batch_stride_C_), + batch_stride_E(batch_stride_C_), batch_count(batch_count_) { } ck_tile::index_t batch_stride_A; ck_tile::index_t batch_stride_B; - ck_tile::index_t batch_stride_C; + ck_tile::index_t batch_stride_E; 
ck_tile::index_t batch_count; }; @@ -46,18 +56,18 @@ struct BatchedGemmKernel : public GemmKernel; - using GemmKernelArgs = typename ck_tile::GemmKernelArgs; + using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>; using ADataType = typename Base::ADataType; using BDataType = typename Base::BDataType; - using CDataType = typename Base::CDataType; + using CDataType = typename Base::EDataType; using TilePartitioner = typename Base::TilePartitioner; using GemmPipeline = typename Base::GemmPipeline; using EpiloguePipeline = typename Base::EpiloguePipeline; using ALayout = typename Base::ALayout; using BLayout = typename Base::BLayout; - using CLayout = typename Base::CLayout; + using CLayout = typename Base::ELayout; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -75,7 +85,7 @@ struct BatchedGemmKernel : public GemmKernel(kargs.b_ptr) + batch_offset_B + splitk_batch_offset.b_k_split_offset; - const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C); - const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C); - CDataType* c_ptr = static_cast(kargs.c_ptr) + batch_offset_C; + const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); + const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); + CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index bfb0d2626b..4cd26c2234 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -16,70 +16,72 @@ namespace ck_tile { -/// @brief The GEMM problem definition. 
-/// -/// @par Overview -/// This structure defines the GEMM problem configuration by stating all required information -/// like M,N,K sizes and respective strides. -struct GemmProblem -{ - CK_TILE_HOST GemmProblem() = default; - CK_TILE_HOST GemmProblem( - index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) - : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) - { - } - - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; -}; - /// @brief The GEMM kernel host arguments. /// /// @par Overview /// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments /// object. It contain all necessary information required to build proper kernel argument /// and launch kernel on GPU. -struct GemmHostArgs : public GemmProblem +/// This structure defines the GEMM problem configuration by stating all required information +/// like M,N,K sizes and respective strides. +/// NumDTensor describes the number of D tensors. 
+template +struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - void* c_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_C_) - : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), - a_ptr(a_ptr_), + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), b_ptr(b_ptr_), - c_ptr(c_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), k_batch(k_batch_) { } const void* a_ptr; const void* b_ptr; - void* c_ptr; + const std::array ds_ptr; + void* e_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + index_t stride_E; index_t k_batch; }; /// @brief The GEMM kernel device arguments. +template struct GemmKernelArgs { /// @brief The A input tensor's pointer to device memory. const void* a_ptr; /// @brief The B input tensor's pointer to device memory. const void* b_ptr; - /// @brief The C output tensor's pointer to device memory. - void* c_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; /// @brief GEMM's M dimension size. index_t M; /// @brief GEMM's N dimension size. @@ -93,8 +95,11 @@ struct GemmKernelArgs /// (in memory) of B tensor. index_t stride_B; /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of C tensor. - index_t stride_C; + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. 
+ index_t stride_E; index_t k_batch; }; @@ -133,16 +138,19 @@ struct GemmKernelArgs /// @tparam EpiloguePipeline_ The type of class providing the final part of matrix /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output C tensor in global memory. +/// the output E tensor in global memory. template struct GemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available @@ -163,11 +171,18 @@ struct GemmKernel using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. 
- using CDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + + static constexpr index_t NumDTensor = DsDataType::size(); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>{}; + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + using KernelArgs = GemmKernelArgs; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -190,7 +205,7 @@ struct GemmKernel CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { using Kernel = GemmKernel; - const auto kernel = kentry; + const auto kernel = kentry; int occupancy; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); @@ -200,18 +215,22 @@ struct GemmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const GemmHostArgs& hostArgs) { - return GemmKernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.c_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_C, - hostArgs.k_batch}; + + return KernelArgs{hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -221,8 +240,7 @@ struct GemmKernel struct SplitKBatchOffset { - __device__ SplitKBatchOffset(const GemmKernelArgs& kargs, - const std::size_t k_id = blockIdx.z) + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); const 
index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); @@ -261,10 +279,10 @@ struct GemmKernel index_t splitted_k; }; - CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) + is_any_of::value) { if(kargs.k_batch != 1) { @@ -360,7 +378,56 @@ struct GemmKernel } } - if constexpr(std::is_same_v) + bool DTesnorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTesnorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + }); + + if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { @@ -400,15 +467,17 @@ struct GemmKernel return false; } } - return true; + return DTesnorIsValid; } template - 
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const GemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); const auto& a_tensor_view = [&]() { @@ -495,29 +564,54 @@ struct GemmKernel } }(); + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) { return make_naive_tensor_view( - c_ptr, + e_ptr, make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), + make_tuple(kargs.stride_E, 1), number{}, number<1>{}); } else { return make_naive_tensor_view( - c_ptr, + e_ptr, make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), + make_tuple(1, kargs.stride_E), number<1>{}, number<1>{}); } }(); - return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view); } template @@ -559,35 +653,57 @@ struct GemmKernel } }(); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + 
if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I2); - if constexpr(std::is_same_v) + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) { - return pad_tensor_view(c_tensor_view, + return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } else { - return pad_tensor_view(c_tensor_view, + return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), sequence{}); } }(); - return make_tuple(a_pad_view, b_pad_view, c_pad_view); + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view); } template CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& c_pad_view = views.at(I2); + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); const auto& a_block_window = [&]() { if constexpr(std::is_same_v) @@ -623,12 +739,32 @@ struct GemmKernel } }(); - auto c_block_window = make_tile_window( - c_pad_view, + const auto ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(a_block_window, 
b_block_window, c_block_window); + return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window); } /** @@ -636,7 +772,8 @@ struct GemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer - * @param c_ptr output C pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. @@ -647,9 +784,10 @@ struct GemmKernel template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, - CDataType* c_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, void* smem_ptr_0, - const GemmKernelArgs& kargs, + const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -657,7 +795,7 @@ struct GemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -668,6 +806,7 @@ struct GemmKernel // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); @@ -675,11 +814,11 @@ struct GemmKernel if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } } @@ -690,7 +829,8 @@ struct GemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer - * @param c_ptr output C pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer * @param smem_ptr_0 The starting pointer of 1st shared memory block. * @param smem_ptr_1 The starting pointer of 2nd shared memory block. 
* @param kargs GEMM kernel arguments @@ -701,10 +841,11 @@ struct GemmKernel */ CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - CDataType* c_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, - const GemmKernelArgs& kargs, + const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -712,7 +853,8 @@ struct GemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -722,20 +864,22 @@ struct GemmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } // Non-persistent kernel entry point template > - CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); @@ -743,12 +887,14 @@ struct GemmKernel const index_t i_n = __builtin_amdgcn_readfirstlane(iN * 
TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); + // options const ADataType* a_ptr = static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + + EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; @@ -758,11 +904,12 @@ struct GemmKernel __shared__ char smem_ptr_1[GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, - c_ptr, + kargs.ds_ptr, + e_ptr, smem_ptr_0, smem_ptr_1, kargs, @@ -775,18 +922,25 @@ struct GemmKernel { if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } } // Persistent kernel entry point template , typename = void> - CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const + CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); const auto num_tiles = @@ -809,7 +963,7 @@ struct GemmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; @@ -820,11 +974,12 @@ struct GemmKernel if constexpr(!(EpiloguePipeline::MemoryOperation == 
memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, - c_ptr, + kargs.ds_ptr, + e_ptr, smem_ptr_0, smem_ptr_1, kargs, @@ -838,9 +993,17 @@ struct GemmKernel if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } // Advance to the next work item diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index f57600d7a5..533cabb736 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -18,17 +18,17 @@ namespace ck_tile { struct GemmTransKernelArg { - GemmKernelArgs group_karg; + GemmKernelArgs<> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; - GemmTransKernelArg() = default; - GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg() = delete; + GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(GemmKernelArgs&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} + GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} }; template @@ -39,7 +39,7 @@ struct GroupedGemmKernel : public GemmKernel; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using ELayout = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -65,8 +65,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) - -> std::size_t + 
CK_TILE_HOST static auto + GetWorkSpaceSize(const std::vector>& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -95,7 +95,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static constexpr auto + GridSize(const std::vector>& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -106,7 +107,8 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static auto + MakeKargs(const std::vector>& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; @@ -127,7 +129,7 @@ struct GroupedGemmKernel : public GemmKernel(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].c_ptr), - M, - N, - K, - stride_a, - stride_b, - stride_c, - gemm_descs[i].k_batch}; + auto karg = GemmKernelArgs<>{type_convert(gemm_descs[i].a_ptr), + type_convert(gemm_descs[i].b_ptr), + {}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + stride_a, + stride_b, + {}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -177,7 +181,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -192,7 +196,7 @@ struct GroupedGemmKernel : public GemmKernel(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - CDataType* c_ptr = static_cast(kargs.c_ptr); + CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; @@ -204,7 +208,7 @@ struct GroupedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } } @@ -230,7 +234,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, 
const index_t block_idx_n) @@ -238,13 +242,14 @@ struct GroupedGemmKernel : public GemmKernel( - a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); const auto& a_block_window = gemm_tile_windows.at(Base::I0); const auto& b_block_window = gemm_tile_windows.at(Base::I1); + const auto& d_block_window = gemm_tile_windows.at(Base::I2); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( @@ -256,9 +261,10 @@ struct GroupedGemmKernel : public GemmKernel( - c_block_window, c_block_tile, smem_ptr_0); + auto& c_block_window = gemm_tile_windows.at(Base::I3); + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 8f9d7ac89b..57afb5cbb5 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(image_to_column) add_subdirectory(gemm) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(gemm_multi_d) add_subdirectory(data_type) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index cffa81d1c5..79bd51d65c 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" template class TestCkTileBatchedGemm : public ::testing::Test @@ -23,6 +24,8 @@ class TestCkTileBatchedGemm : public ::testing::Test 
using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; template void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, @@ -102,9 +105,12 @@ class TestCkTileBatchedGemm : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(args, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index b3146b5f8e..5f2a53645d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -76,12 +76,17 @@ class TestCkTileGemmPipeline : public ::testing::Test using CDataType = std::tuple_element_t<6, Tuple>; static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value; static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value; + + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + static constexpr bool Persistent = ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? 
template - void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; @@ -165,9 +170,12 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; args.stride_A = stride_A; args.stride_B = stride_B; - args.stride_C = stride_C; + args.stride_E = stride_C; invoke_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..1ec77eb87a --- /dev/null +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -0,0 +1,4 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_gemm_multi_d test_gemm_multi_d.cpp) +endif() diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp new file mode 100644 index 0000000000..a634d825b7 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_d_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, D0Layout, D1Layout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, CDataType, CDElementWiseFn + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, ElementWiseAddAdd>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F16, ElementWiseAddAdd>, + + std::tuple< Row, Col, Row, Row, Row, F16, F16, F16, F16, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F16, F16, F32, F32, F32, F16, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, BF16, BF16, F32, F32, MultiplyMultiply>, + std::tuple< Row, Col, Row, Row, Row, F8, F8, F8, F8, F32, F32, MultiplyMultiply> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiD, KernelTypes); + +#include "test_gemm_multi_d_ut_cases.inc" diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc new file mode 100644 index 0000000000..22d887fa83 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_ut_cases.inc @@ -0,0 +1,334 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x512x256) +{ 
+ constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch1_1280x256x256) +{ + constexpr int M = 
1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + 
+TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch1_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x512x512) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDAddKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, 
kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x512x512) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x256x512) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x256x512) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x768x512) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_512x1280x512) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_256x1280x512) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_768x512x512) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x512x512) +{ + constexpr int M = 1280; + 
constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} + +TYPED_TEST(TestCkTileGemmMultiD, TestCkTileGemmMultiDMultiplyMultiplyKBatch2_1280x256x512) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + this->Run(M, N, K, kBatch); +} diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp new file mode 100644 index 0000000000..7dd91077b1 --- /dev/null +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +struct ElementWiseAddAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto 
atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileGemmMultiD : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using BDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using AccDataType = std::tuple_element_t<9, Tuple>; + using EDataType = std::tuple_element_t<10, Tuple>; + using CDElementWiseFn = std::tuple_element_t<11, Tuple>; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + template + void invoke_gemm_multi_d(const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t 
TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) + { + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." 
+ << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + public: + void Run(const int M, + const int N, + const int K, + const int k_batch, + int StrideA = 0, + int StrideB = 0, + int StrideD0 = 0, + int StrideD1 = 0, + int StrideE = 0) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{}); + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + + ck_tile::HostTensor a_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + ck_tile::HostTensor b_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + ck_tile::HostTensor d0_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + ck_tile::HostTensor d1_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tesnor); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 
1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data()); + b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data()); + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + std::array stridesDs = {StrideD0, StrideD1}; + + ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE}); + + invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); + + std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE + << " StrideD0 =" << StrideD0 << " StrideD1 =" << StrideD1 << std::endl; + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + bool pass = true; + + ck_tile::HostTensor e_m_n_host_ref( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + K, k_batch, max_accumulated_value); + pass = ck_tile::check_err(e_m_n_device_result, + 
e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + EXPECT_TRUE(pass); + } +}; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 382a32a7d9..54f772f89e 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" template class TestCkTileGroupedGemm : public ::testing::Test @@ -23,6 +24,8 @@ class TestCkTileGroupedGemm : public ::testing::Test using BDataType = std::tuple_element_t<4, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>; using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; // Get the persistent value from ck_tile::bool_constant using PersistentType = std::tuple_element_t<7, Tuple>; @@ -48,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; - using grouped_gemm_kargs = ck_tile::GemmHostArgs; + using grouped_gemm_kargs = ck_tile::GemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); @@ -127,9 +130,12 @@ class TestCkTileGroupedGemm : public ::testing::Test using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblemGetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, 
stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; @@ -442,16 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test const bool splitk = gemm_descs[0].k_batch > 1; for(const auto& arg : gemm_descs) { - kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.c_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - arg.stride_C, - arg.k_batch}); + kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, + arg.b_ptr, + {}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + {}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, false, 1}; ck_tile::hip_check_error(