diff --git a/CHANGELOG.md b/CHANGELOG.md index 38669385f3..dafe1b5c87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added support for Multiple D GEMM +* Added support for Multiple ABD GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. diff --git a/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..f382e0cf45 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp) diff --git a/example/ck_tile/22_gemm_multi_abd/README.md b/example/ck_tile/22_gemm_multi_abd/README.md new file mode 100644 index 0000000000..c272df3fb5 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/README.md @@ -0,0 +1,35 @@ +#Multiple ABD GEMM + +This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation. 
+ +## build +``` +#in the root of ck_tile +mkdir build && cd build +#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ + leave it blank +sh ../script/cmake-ck-dev.sh ../ +#The basic pipeline method on the gemm calculation +make tile_example_gemm_multi_abd_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16` + +## example +``` +args: + -m M dimensions - (Default: 3840) + -n N dimensions - (Default: 4096) + -k K dimensions - (Default: 4096) +-as_layout Tensor A layout (default:R) +-bs_layout Tensor B layout (default:C) +-ds_layout Tensor D layout (default:R) +-e_layout Tensor E layout (default:R) +-stride_as Tensor A strides - (Default: 0) +-stride_bs Tensor B strides - (Default: 0) +-stride_e Tensor C strides - (Default: 0) +-stride_ds Tensor D strides - (Default: 0) +-validate 0. No validation, 1. Validation on GPU. (Default: 1) + -warmup Number of iterations before benchmark the kernel. (Default: 10) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. (Default: 1) +``` \ No newline at end of file diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp new file mode 100644 index 0000000000..6d955c3a09 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_abd_fp16.hpp" +#include "utils.hpp" + +template +auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float +{ + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; + + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; + + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; + + constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +#include "run_gemm_multi_abd_fp16_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_multiple_abd_gemm_example(argc, argv); +#else + return !run_multiple_abd_gemm_example(argc, argv); +#endif +} diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp new file mode 100644 index 0000000000..35bc232eca --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +using A0DataType = ck_tile::half_t; +using A1DataType = ck_tile::half_t; + +using B0DataType = ck_tile::half_t; +using B1DataType = ck_tile::half_t; + +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; + +using EDataType = ck_tile::half_t; + +using AsDataType = ck_tile::tuple; +using BsDataType = ck_tile::tuple; +using DsDataType = ck_tile::tuple; + +using AccDataType = float; + +struct GemmConfigMemory +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t 
N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV4 +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using 
UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("as_layout", "R", "As tensor data layout - Row by default") + .insert("bs_layout", "C", "Bs tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_as", "0", "Tensor A stride") + .insert("stride_bs", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") + .insert("v", "1", "0. No validation, 1. 
Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} +using gemm_multi_abd_kargs = + ck_tile::GemmMultiABDHostArgs; + +template +float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc new file mode 100644 index 0000000000..881961c9db --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +template +float invoke_gemm_multi_abd(const std::array& as_m_k_dev_buf, + const std::array& bs_k_n_dev_buf, + const std::array& ds_m_n_dev_buf, + void* e_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + const std::array& StrideAs, + const std::array& StrideBs, + const std::array& StrideDs, + ck_tile::index_t StrideE, + int n_warmup, + int n_repeat, + int k_batch) +{ + gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf, + bs_k_n_dev_buf, + ds_m_n_dev_buf, + e_m_n_dev_buf, + k_batch, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + float ave_time = gemm_multi_abd( + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm Multiple-ABD"}; + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + num_btype += + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm 
Multiple-ABD kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0] + << " StrideE = " << StrideE << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; + + return ave_time; +} + +template +int run_gemm_multi_abd_example_with_layouts(int argc, + char* argv[], + const A0Layout a0_layout = A0Layout{}, + const A1Layout a1_layout = A1Layout{}, + const B0Layout b0_layout = B0Layout{}, + const B1Layout b1_layout = B1Layout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + using AElementWiseFn = ck_tile::element_wise::AddScale; + using BElementWiseFn = ck_tile::element_wise::AddScale; + using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply; + using AsLayout = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t StrideA = arg_parser.get_int("stride_as"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_bs"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); + + ck_tile::index_t StrideA0 = StrideA; + ck_tile::index_t StrideA1 = StrideA; + + ck_tile::index_t StrideB0 = StrideB; + ck_tile::index_t StrideB1 = StrideB; + + ck_tile::index_t StrideD0 = StrideD; + ck_tile::index_t StrideD1 = StrideD; + + const int n_warmup = arg_parser.get_int("warmup"); + const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); + + StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout)); + StrideA1 = get_default_stride(M, N, 
StrideA1, is_row_major(a1_layout)); + + StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout)); + StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout)); + + StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout)); + StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout)); + + StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout)); + + ck_tile::HostTensor a0_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor a1_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout))); + + ck_tile::HostTensor b0_k_n_tensors( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor b1_k_n_tensors( + host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout))); + + ck_tile::HostTensor d0_m_n_tensors( + host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout))); + ck_tile::HostTensor d1_m_n_tensors( + host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout))); + + ck_tile::HostTensor e_m_n_device_result( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + invoke_gemm_multi_abd(as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + + ck_tile::HostTensor a_m_k_host_ref_element_result( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor b_k_n_host_ref_element_result( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor e_m_n_host_ref( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + bool pass{true}; + 
if(arg_parser.get_int("v")) + { + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + pass &= ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << std::endl; + std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + return pass; +} + +template +int run_multiple_abd_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string as_layout = arg_parser.get_str("as_layout"); + const std::string bs_layout = arg_parser.get_str("bs_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(as_layout == "R" && bs_layout == "C") + { + return run_gemm_multi_abd_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/22_gemm_multi_abd/utils.hpp b/example/ck_tile/22_gemm_multi_abd/utils.hpp new file mode 100644 index 0000000000..38bf8623d4 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/utils.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8fce70ba04..75d32a5eb0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) +add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 8b7541bf23..c7c4702e22 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(number{}, bool_constant{}); } +/** + * @brief Load tile with elementwise function + * + * @note This function is a modification of the existing load function. 
+ * It has been extended with two additional parameters: it takes a tuple as input + * and an elementwise function. For each A = A0, A1… AN, the elementwise function + * is additionally applied during a single read. + */ +template +CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) +{ + // TODO: Tile windows should works with unknow number of params + // Load element_wise API works only when the input typle is a tuple-tyupe + return tile_window[number<0>{}].load( + tile_window, elementwise, number{}, bool_constant{}); +} + template + CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = typename Base::TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, + tile_window, + elementwise, + number{}, + bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + + using Traits = typename Base::Traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = typename Base::TileDstr{}; + constexpr auto sizeOfTuple = TileWindow_::size(); + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] 
+ constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const auto idx_vec_value = generate_tuple( + [&](auto jj) { + return tile_window[number{}] + .get_bottom_tensor_view() + .template get_vectorized_elements( + bottom_tensor_thread_coord, + 0, + bool_constant{}); + }, + number{}); + + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + ck_tile::apply( + [&](auto&&... t) { + elementwise(dst_tensor.get_thread_buffer().template at(), + t.template get_as< + typename Base::DataType>()[j / Traits::PackedSize]...); + }, + idx_vec_value); + }); + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + template @@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +template +CK_TILE_DEVICE void move_tile_window( + tuple>& window, + const typename tile_window_with_static_distribution::BottomTensorIndex& step) +{ + using T = tuple>; + + static constexpr auto N = T::size(); + static_for<0, N, 1>{}([&](auto Is) { window[number{}].move(step); }); +} + +template ::value>* = nullptr> +CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step) +{ + static constexpr auto N = TileWindowWithStaticDistributionType::size(); + static_for<0, N, 1>{}([&](auto Is) { 
window[number{}].move(step); }); +} + /** * @brief This class provides description of tile windowed view on the device memory. * diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index caa00e5994..d9379b4420 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template >, + typename BDataType = remove_cvref_t>, + typename DDataType = remove_cvref_t>> +CK_TILE_HOST void +reference_gemm_multiple_abd(const std::array, AsDataType::size()>& as_m_k, + const std::array, BsDataType::size()>& bs_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& a_m_k, + HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const CDElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto as_m_k_tuple = + generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number{}); + + auto bs_k_n_tuple = + generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number{}); + + auto ds_m_n_tuple = + generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number{}); + + // Apply elementwise function to A + auto a_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple); + }; + + make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency()); + + // Apply elementwise function to B + auto b_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... 
t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple); + }; + + make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency()); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + + ck_tile::apply( + [&](auto&&... t) { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(t(m, n))...); + }, + ds_m_n_tuple); + + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); +} + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const + { + // Start with the base value c + float result = ck_tile::type_convert(0.0f); + + // Add by each D parameter using fold expression + ((result += ck_tile::type_convert(as)), ...); + + a = ck_tile::type_convert(scale * result); + } + + float scale = 1.0; +}; + struct MultiDMultiply { template diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 628af0e0b3..ebd97c1c66 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -28,8 +28,8 @@ struct GetDataType using type = typename T::DataType; // Use T::ScaleN::DataType }; -template struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; @@ -83,12 +83,27 @@ template struct CShuffleEpilogue { using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = 
remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using ATypeToUse = std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 54becd3c0f..2843966cd7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -28,8 +28,8 @@ struct Default2DEpilogueProblem static constexpr index_t NumDTensor = 0; }; -template { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CLayout = remove_cvref_t; using DsDataType = remove_cvref_t; using CDElementwise = remove_cvref_t; @@ -157,14 +157,28 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + 
using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index de13e305e0..6e07dbc00e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,6 +31,7 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index fcfbf9635f..588d903b25 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -90,10 +90,10 @@ struct BatchedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. 
static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e37b4f36d4..d632b1596c 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -89,7 +89,7 @@ struct GemmKernel /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; @@ -106,10 +106,10 @@ struct GemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); static constexpr index_t NumATensor = 1; static constexpr index_t NumBTensor = 1; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp new file mode 100644 index 0000000000..3b050e03ed --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiABD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). 
+template +struct GemmMultiABDHostArgs +{ + CK_TILE_HOST GemmMultiABDHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernelMultiABD +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be a tuple, not a scalar. 
+ static_assert(is_detected::value && + is_detected::value, + "ALayout and ADataType must be a tuple."); + + /// @brief BLayout and BDataType are expected to be a tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value, + "BLayout and BDataType must be a tuple."); + + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "CLayout and EDataType must be a scalar."); + + /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size() && DsLayout::size() > 0, + "DsLayout and DsDataType must be tuples and must have the same size."); + + /// @brief The sizes of NumATensor, NumBTensor and NumDTensor is set by the user." + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const GemmMultiABDHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B, and D. 
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs(hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + // Currently MultiABD kernel doesn't support k_batch > 1 + if(kargs.k_batch > 1) + { + return false; + } + + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 9d3ac8b901..b0b2905cb4 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -95,7 +95,7 @@ struct GemmKernelMultiD /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using DsLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D @@ -114,10 +114,10 @@ struct GemmKernelMultiD !is_detected::value, "BLayout and BDataType must be scalars."); - /// @brief ELayout and EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "ELayout and EDataType must be scalars."); + "CLayout and EDataType must be scalars."); /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. 
static_assert(is_detected::value && diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index e38e49f5d1..df1d6c9e4f 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -120,10 +120,10 @@ struct GroupedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; using Kernel = GroupedGemmKernel; @@ -364,12 +364,8 @@ struct GroupedGemmKernel const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}.template diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index cfba8b6c9d..8f44108cc4 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -157,23 +157,23 @@ struct UniversalGemmKernel using EpiloguePipeline = remove_cvref_t; static constexpr bool ADataTypeIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BDataTypeIsTuple = - 
is_detected::value; + is_detected::value; static constexpr bool DDataTypeIsTuple = is_detected::value; static constexpr bool ALayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BLayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool DLayoutIsTuple = is_detected::value; using AsLayout = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsLayout = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsLayout = std::conditional_t>>; using AsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsDataType = @@ -193,9 +193,12 @@ struct UniversalGemmKernel remove_cvref_t, remove_cvref_t>>; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using EDataType = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available @@ -483,7 +486,7 @@ struct UniversalGemmKernel bool DTesnorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) + if(std::is_same_v == false) { DTesnorIsValid = false; } @@ -529,7 +532,7 @@ struct UniversalGemmKernel } }); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { @@ -724,7 +727,7 @@ struct UniversalGemmKernel // TODO: enable vector write for C in ColMajor const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return make_naive_tensor_view( e_ptr, @@ -818,7 +821,7 @@ struct UniversalGemmKernel // TODO vector write in for C in ColMajor const auto& e_pad_view = [&]() { const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return 
pad_tensor_view(e_tensor_view, make_tuple(number{}, @@ -975,8 +978,8 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = - GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); if(UseDefaultScheduler || (get_warp_id() == 0)) { @@ -1031,8 +1034,13 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}( - as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, + AElementWise{}, + bs_block_window, + BElementWise{}, + num_loop, + smem_ptr_0, + smem_ptr_1); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 2bee550b3c..b5584f98df 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -11,12 +11,17 @@ namespace ck_tile { template struct GemmPipelineAgBgCrImplBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + using BLayout = remove_cvref_t{}, BsLayout>>; 
+ static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile) const + { + store_tile(lds_tile_window, src_block_tile); + } + template CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, const SrcTileWindow& lds_tile_window, @@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); } + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeADramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template 
MakeADramTileDistribution()); + + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + + return std::move(a_copy_dram_window); + } + template CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_col_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - 
a_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeADramTileDistribution()); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); // A LDS tile window for store auto a_lds_shape = []() { @@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase const BLdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_row_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - b_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeBDramTileDistribution()); + // A DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); // TODO: Do we really need those two tile windows??? // They're exactly same... diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 5f4ee8987e..7159eda683 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = 
remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; using I0 = number<0>; @@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - ABlockTile a_block_tile; - BBlockTile b_block_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // ----------------------------------------------------------------------------------------- // Gemm pipeline start - - // prefetch - // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - 
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, 
b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // global read 1 + + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); __builtin_amdgcn_sched_barrier(0); @@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major && !is_a_load_tr_v()) + if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major && !is_b_load_tr_v()) + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + 
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, 
b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); @@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 * @note This is used by the persistent gemm kernel variants that don't determine * hot loop and tail number on the host side, e.g. grouped gemm kernel. 
*/ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 * @note This is used by the kernel variants that are able to determine * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. 
*/ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline by wrapping it with + * the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. 
+ */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline using compile-time + * known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index c835809b5d..b362f751c6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 using Base = BaseGemmPipelineAgBgCrCompV4; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; 
using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static_assert(!std::is_same_v, "Not implemented"); static constexpr index_t APackedSize = @@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; @@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), "B block window has incorrect lengths for defined BLayout!"); - ////////////// global window & register ///////////////// - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - 
make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // A register tile for global load - constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); - constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); - using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr)); - using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr)); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global prefetch 0 // global read 0 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + ////////////// LDS desc, window & register ///////////////// auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); @@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Generating a tuple with tile_windows for values A0, A1, ... 
AN + auto a_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + // Generating a tuple with tile_windows for values B0, B1, ... BN + auto b_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(b_tile_windows, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } // global read 1 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); block_sync_lds(); constexpr auto ALdsTileDistr = @@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - 
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. 
+ elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); if(HasHotLoop) { @@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); // gemm block_gemm(c_block_tile, a_block_tile0, b_block_tile0); 
HotLoopScheduler(); @@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } block_sync_lds(); - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); + // gemm block_gemm(c_block_tile, a_block_tile1, b_block_tile1); HotLoopScheduler(); @@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, 
a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } block_gemm(c_block_tile, a_block_tile0, b_block_tile0); } @@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0, @@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 p_smem_1); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const 
BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0, p_smem_1); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0, + 
p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0, + p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem_0, + p_smem_1); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index b83d37a790..474d1a5a21 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using Base = BaseGemmPipelineAgBgCrCompV5; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + 
using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; @@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BsDramBlockWindowTmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v BGemmTile b_tile_0, b_tile_1; // Register tile for A and B. 
- using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; + ABlockTile elementwise_As_res; + BBlockTile elementwise_Bs_res; // Block GEMM auto block_gemm = BlockGemm(); @@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // define ping, pong steps here as lambda functions. auto MemoryOpsStep = [&](auto idx) { // Memory read half here. - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // A0, A1, … AN. The values A0, A1, … AN are read by the same thread. + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // B0, B1, … BN. The values B0, B1, … BN are read by the same thread. + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. 
+ move_tile_window(b_copy_dram_window, b_dram_tile_window_step); if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } if(idx == 0) @@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0) const @@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 p_smem_0); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& 
b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e1acfebc47..9e522d4364 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using Base = BaseGemmPipelineAgBgCrMem; using 
PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + 
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % 
PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } 
if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, 
a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, 
b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - 
Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const 
BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem p_smem); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void* p_smem) const @@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp&
a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index e3b4863392..eb363d59b8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -15,14 +15,23 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = 
remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1 return Policy::template GetSmemSize(); } - template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v>, @@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store 
auto a_copy_lds_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window( @@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write 0 @@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); - store_tile(b_copy_lds_window, b_block_tile_tmp); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + store_tile(b_copy_lds_window, elementwise_Bs_res); } } @@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - b_block_tile = load_tile(b_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); block_sync_lds(); @@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, 
kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 if constexpr(is_a_col_major) { auto a_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, a_block_tile); - store_tile(a_copy_lds_window, - tile_elementwise_in(a_element_func, a_shuffle_tmp_loop)); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); } else { - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write i + 1 @@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, b_block_tile); - store_tile(b_copy_lds_window, - tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); } else { - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); } iCounter--; @@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1 return c_block_tile; } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const 
BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index b151cd6782..c309f8908a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -15,30 +15,66 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV2 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + 
template + static constexpr index_t GetVectorSizeA() + { + return Problem::VectorSizeA; + } + template + static constexpr index_t GetVectorSizeB() + { + return Problem::VectorSizeB; + } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool Preshuffle = Problem::Preshuffle; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false. + static constexpr bool DoubleSmemBuffer = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", - concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize)); + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize)); // clang-format on } CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2 BPackedSize; } - template (); + } + + template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert(
std::is_same_v> && std::is_same_v>, @@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store auto a_copy_lds_window = make_tile_window(a_lds_block, make_tuple(number{}, number{}), {0, 0}, - a_copy_dram_window.get_tile_distribution()); + as_copy_dram_window[number<0>{}].get_tile_distribution()); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window(b_lds_block, make_tuple(number{}, number{}), {0, 0}, - b_copy_dram_window.get_tile_distribution()); + bs_copy_dram_window[number<0>{}].get_tile_distribution()); // Block GEMM constexpr auto block_gemm = Policy::template GetBlockGemm(); @@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = 
load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read 1 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read 1 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); } index_t iCounter = num_loop - 2; @@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 - const auto a_block_tile_tmp = 
tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read i + 2 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read i + 2 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); iCounter--; @@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // LDS write num_loop - 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); block_sync_lds(); @@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..c73fa29245 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -5,16 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { -template @@ -22,18 +25,49 @@ struct GemmPipelineProblemBase { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using 
ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; @@ -66,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -84,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; @@ -125,7 +159,7 @@ struct GemmPipelineProblemBase { return VectorSizeA_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -140,7 +174,7 @@ struct GemmPipelineProblemBase { return VectorSizeB_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadN ? 
1 : GetAlignmentB(); } @@ -161,35 +195,40 @@ struct GemmPipelineProblemBase }(); }; -// Alias for GemmPipelineProblem -template -using GemmPipelineProblem = GemmPipelineProblemBase; -template @@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr 
index_t NumWaveGroups = Traits::NumWaveGroups; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8d47ab878e..c8f874acd6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using ALayout = remove_cvref_t{}, AsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BsLayout = remove_cvref_t; + using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? 
Problem::VectorSizeA : GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { @@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { - using ALayout = remove_cvref_t; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; @@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() { - using BLayout = remove_cvref_t; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 64900c9a97..96203b2cd2 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -10,8 +10,8 @@ namespace ck_tile { 
template struct TileGemmTraits @@ -23,9 +23,9 @@ struct TileGemmTraits // TODO this can't be hardcoded here! Should be in policy! static constexpr int _VectorSize = 16; - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; @@ -36,8 +36,8 @@ template @@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -188,7 +197,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 } } - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -455,7 +469,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 return c_block_tile; } - template + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + 
[[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -463,7 +503,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 129eac6557..356ad91448 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -53,14 +53,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = 
remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -502,7 +511,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 template + typename AElementFunction, + typename std::enable_if_t::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -1001,8 +1013,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 return c_block_tile; } + // called from universal gemm kernel + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem_ping, + p_smem_pong); + } + // called from general gemm kernel - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -1019,9 +1060,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } // called from grouped gemm kernel - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_flat_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, void* __restrict__ p_smem_0, diff --git 
a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp index 44c6cd66c6..f505efe4e0 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp @@ -44,6 +44,10 @@ struct TileGemmQuantTraits using AQLayout = AQLayout_; using BQLayout = BQLayout_; + // TODO: It should be replaced to single value + using AsLayout = ALayout_; + using BsLayout = BLayout_; + static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 9314d4b795..b08f0d8316 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..ac3b59d5d3 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -0,0 +1,12 @@ +# Currently ck_tile is only built on gfx9 +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) + add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) + target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff 
--git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp new file mode 100644 index 0000000000..9821963458 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue enabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, 
F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_cshuffle.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp new file mode 100644 index 0000000000..b3a89aba05 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue disabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, 
Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc new file mode 100644 index 0000000000..5aa113608f --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + 
constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x256x256) +{ + 
constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), 
std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc new file mode 100644 index 0000000000..cc7603164c --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + 
constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x256) +{ + constexpr int M = 512; + 
constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff 
--git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp new file mode 100644 index 0000000000..428bed4e25 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +struct AddScale +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const + { + a = scale * (ck_tile::type_convert(a0) + ck_tile::type_convert(a1)); + } + + float scale = 1.0; +}; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +struct ElementWiseAddAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + 
ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileGemmMultiABD : public ::testing::Test +{ + protected: + using A0Layout = std::tuple_element_t<0, Tuple>; + using A1Layout = std::tuple_element_t<1, Tuple>; + using B0Layout = std::tuple_element_t<2, Tuple>; + using B1Layout = std::tuple_element_t<3, Tuple>; + using D0Layout = std::tuple_element_t<4, Tuple>; + using D1Layout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + using A0DataType = std::tuple_element_t<7, Tuple>; + using A1DataType = std::tuple_element_t<8, Tuple>; + using B0DataType = std::tuple_element_t<9, Tuple>; + using B1DataType = std::tuple_element_t<10, Tuple>; + using D0DataType = std::tuple_element_t<11, Tuple>; + using D1DataType = std::tuple_element_t<12, Tuple>; + using AccDataType = std::tuple_element_t<13, Tuple>; + using EDataType = std::tuple_element_t<14, Tuple>; + using AElementWiseFn = std::tuple_element_t<15, Tuple>; + using BElementWiseFn = std::tuple_element_t<16, Tuple>; + using CDElementWiseFn = std::tuple_element_t<17, Tuple>; + using UseCshuffleEpilog = std::tuple_element_t<18, Tuple>; + + using AsLayout = ck_tile::tuple; + using AsDataType = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using BsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + template + void invoke_gemm_multi_abd(const ck_tile::GemmMultiABDHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Tile = 
256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using DefaultGemmEpilogue = 
ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using GemmEpilogue = std:: + conditional_t; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + bool Run(const int M, + const int N, + const int K, + const int k_batch, + int StrideA0 = 0, + int StrideA1 = 0, + int StrideB0 = 0, + int StrideB1 = 0, + int StrideD0 = 0, + int StrideD1 = 0, + int StrideE = 0) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return 
ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA0 = f_get_default_stride(M, K, StrideA0, A0Layout{}); + StrideA1 = f_get_default_stride(M, K, StrideA1, A1Layout{}); + + StrideB0 = f_get_default_stride(K, N, StrideB0, B0Layout{}); + StrideB1 = f_get_default_stride(K, N, StrideB1, B1Layout{}); + + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{}); + + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + + ck_tile::HostTensor a0_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor a1_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA1, A1Layout{})); + + ck_tile::HostTensor b0_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor b1_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); + + ck_tile::HostTensor d0_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + ck_tile::HostTensor d1_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + ck_tile::GemmMultiABDHostArgs + args({as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE}); + + invoke_gemm_multi_abd(args, ck_tile::stream_config{nullptr, false}); + + std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA0 =" << StrideA0 << " StrideA1 =" << StrideA1 + << " StrideB0 =" << StrideB0 << " StrideB1 =" << StrideB1 + << " StrideE =" << StrideE << " StrideD0 =" << StrideD0 + << " StrideD1 =" << StrideD1 << std::endl; + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + bool 
pass = true; + + ck_tile::HostTensor a_m_k_host_ref_element_result( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor b_k_n_host_ref_element_result( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor e_m_n_host_ref( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + K, k_batch, max_accumulated_value); + pass = ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + return pass; + } +};