diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index a58612cb5b..76431cae7a 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -30,3 +30,5 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() + +add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp new file mode 100644 index 0000000000..54abab2f60 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct AddAdd +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c + d0 + d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3 + // clang-format off + //#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S| | | + < A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp index 580f38a79f..086ea45d10 100644 --- a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -184,7 +184,6 @@ int main(int argc, char* argv[]) b0_device_buf.ToDevice(b0_k_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -220,11 +219,12 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50}); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -233,8 +233,6 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) { Tensor c_m_n({M, N}); diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index ef0b5286ac..3dff1b28c6 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #ifndef __HIPCC_RTC__ @@ -149,6 +149,107 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator #endif }; +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is +/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter +/// KBatch which is explicitly passed as 1 by this wrapper. +template +struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleDSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, // KBatch + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..49fa6676cd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise +// operations that could be applied on each tensor respectively. The CDE_op is an +// elementwise operation applied to the C and all D tensors. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemmMultipleD_Wmma_CShuffleV3 + : public DeviceGemmMultipleDSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + std::array size_ds_buffers; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + size_ds_buffers[i] = + ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + size_ds_buffers); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, + 0, + arg_.M * arg_.N * sizeof(EDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, + 0, + arg.M * arg.N * sizeof(EDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0]; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + str << std::string(DLayout::name)[0]; + }); + str << std::string(ELayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< { - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, + Tuple<>, // DsLayout CLayout, ADataType, BDataType, AccDataType, CShuffleDataType, + Tuple<>, // DsDataType CDataType, AElementwiseOperation, BElementwiseOperation, @@ -219,7 +220,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -294,7 +295,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, 0, arg_.M * arg_.N * sizeof(CDataType), stream_config.stream_id_)); @@ -312,7 +313,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); @@ -468,11 +469,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2{}, // p_ds_grid_ + p_c, + M, + N, + K, + StrideA, + StrideB, + std::array{}, // StrideDs_ + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -488,20 +503,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2(static_cast(p_a), static_cast(p_b), + std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, StrideA, StrideB, + std::array{}, // StrideDs_ StrideC, - KBatch); + KBatch, + a_element_op, + b_element_op, + c_element_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f3354cd5dd..9dce3f22a3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -19,7 +19,7 @@ namespace ck { template __global__ void @@ -31,22 +31,26 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) { #endif __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op); #if defined(__gfx11__) } #endif @@ -59,9 +63,10 @@ __global__ void /// /// @par Overview /// This GEMM kernel is carrying out following mathematical equation: -/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) -/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are -/// elementwise operations that could be applied on each tensor respectively. +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise +// operations that could be applied on each tensor respectively. The CDE_op is an +// elementwise operation applied to the C and all D tensors. /// The \"universal\" gemm comes with multiple pipelines optimized for different usage /// scenarios. That's why it's called \"universal\". It's universal through it's design /// and versatilty. @@ -73,18 +78,20 @@ __global__ void /// /// @tparam ALayout A tensor data layout. /// @tparam BLayout B tensor data layout. -/// @tparam CLayout C tensor data layout. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. /// @tparam ADataType A tensor data type. /// @tparam BDataType B tensor data type. /// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into /// LDS memory during \"CShuffle\" data layout optimization. -/// @tparam CDataType C tensor data type. -/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. -/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. -/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor -/// (after GEMM). +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. /// @tparam GemmSpec Determines used "padding" version. /// @tparam BlockSize The number of threads within workgroup. /// @tparam MPerBlock The input/output data tile size in the M dimension. @@ -142,11 +149,12 @@ __global__ void /// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions /// results to process per wave per iteration of CShuffle /// in N dimension. -/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial /// thread distribution used for storing data into output /// tensor across output data layout dimensions. -/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. -/// Used when storing data to output tensor. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. /// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or /// intrawave). /// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. @@ -160,15 +168,17 @@ __global__ void /// in global memory (pre-shuffled). template {}; static constexpr auto I7 = Number<7>{}; + static constexpr auto EShuffleBlockTransferScalarPerVector = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> static constexpr auto AK0Number = Number{}; static constexpr auto BK0Number = Number{}; @@ -530,17 +543,18 @@ struct GridwiseGemm_wmma_cshuffle_v3 return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); } + template __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) { const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1)); } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE)); } }(); @@ -593,6 +607,44 @@ struct GridwiseGemm_wmma_cshuffle_v3 #endif } + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDEGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + struct Problem { __host__ Problem(index_t M_, @@ -600,14 +652,16 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t KBatch_) : M{M_}, N{N_}, K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, - StrideC{StrideC_}, + StrideDs{StrideDs_}, + StrideE{StrideE_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -627,8 +681,16 @@ struct GridwiseGemm_wmma_cshuffle_v3 << "N:" << N << ", " << "K:" << K << ", " << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " + << "SB:" << StrideB << ", "; + if constexpr(NumDTensor > 0) + { + std::cout << "SDs: { "; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : ""); + }); + std::cout << " }, "; + } + std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " @@ -644,7 +706,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t K; index_t StrideA; index_t StrideB; - index_t StrideC; + std::array StrideDs; + index_t StrideE; index_t KBatch; index_t MPadded; index_t NPadded; @@ -661,21 +724,35 @@ struct GridwiseGemm_wmma_cshuffle_v3 { __host__ Argument(const ADataType* p_a_grid_, const BDataType* p_b_grid_, - CDataType* p_c_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, - index_t StrideC_, + std::array StrideDs_, + index_t StrideE_, index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, - p_c_grid{p_c_grid_}, + p_ds_grid{}, + p_e_grid{p_e_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); } __host__ __device__ inline bool IsReduceAdd() const @@ -690,42 +767,49 @@ struct GridwiseGemm_wmma_cshuffle_v3 const ADataType* p_a_grid; const BDataType* p_b_grid; - CDataType* p_c_grid; + DsGridPointer p_ds_grid; + EDataType* p_e_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CDEElementwiseOperation cde_element_op; + + // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; }; struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + b_k_split_offset = k_id * karg.KRead / BPackedSize; } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + b_k_split_offset = k_id * k0_offset / BPackedSize; } } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } @@ -736,7 +820,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(karg.IsReduceAdd()) { - c_reduce_offset = blockIdx.z * karg.M * karg.N; + c_reduce_offset = k_id * karg.M * karg.N; } else { @@ -1143,7 +1227,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + std::cout << "Arg K value is not a multiple of K_Batch * KPerBlock! K: " << karg.K << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } @@ -1219,56 +1303,51 @@ struct GridwiseGemm_wmma_cshuffle_v3 } } - if constexpr(is_same::value) + if constexpr(is_same::value) { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.N % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } } else { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.M % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } } - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) { - if(!karg.IsReduceAdd()) + if(karg.IsAtomicAdd() && karg.KBatch > 1) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this " + << "destination type (EDataType) " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1301,18 +1380,18 @@ struct GridwiseGemm_wmma_cshuffle_v3 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); } - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + template + __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, + const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + de_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), make_unmerge_transform(make_tuple(NBlock, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - return c_grid_desc_mblock_mperblock_nblock_nperblock; + return de_grid_desc_mblock_mperblock_nblock_nperblock; } // return block_id to C matrix tile idx (m0, n0) mapping @@ -1322,30 +1401,40 @@ struct GridwiseGemm_wmma_cshuffle_v3 template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, void* p_shared, const Problem& problem, const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1355,8 +1444,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } @@ -1483,7 +1572,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 c_thread_buf, num_k_block_main_loop); - // shuffle C and write out + // Epilogue: shuffle C for better memory access pattern, apply elementwise operation to + // C and Ds, write out result E to global memory { // C mapping in single thread. constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = @@ -1601,31 +1691,63 @@ struct GridwiseGemm_wmma_cshuffle_v3 m_thread_data_on_block_idx[I3]), ck::tensor_operation::element_wise::PassThrough{}}; - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy which loads C from LDS, D from global, applies elementwise + // operation and stores result E to global + auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOps, Sequence<1, CShuffleMRepeatPerShuffle * MWave * MPerWmma, 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // DstDimAccessOrder, + 3, // SrcVectorDim, + 3, // DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors + EShuffleBlockTransferScalarPerVector, // DstScalarPerVector + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; // space filling curve for local reg & global memory // space filling curve for threadwise C in VGPR @@ -1641,7 +1763,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 MAccVgprs>>{}; // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = + constexpr auto sfc_cde_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, Sequence<1, @@ -1651,7 +1773,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS @@ -1668,57 +1790,78 @@ struct GridwiseGemm_wmma_cshuffle_v3 // make sure it's safe to read from LDS block_sync_lds(); - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + // each block loads its C data from LDS, D from global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); if constexpr(access_id < num_access - 1) { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); } }); } } template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, - CDataType* p_c_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto e_grid_desc_m_n = MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, problem.MBlock, problem.NBlock); Run(p_a_grid, p_b_grid, - p_c_grid, + p_ds_grid, + p_e_grid, p_shared, problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op); } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp index 555b52de75..c93e609b7a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_XDL) void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>&); +#endif // CK_USE_WMMA // GEMM + Add + FastGelu +// DeviceGemmMultipleDSplitK specialization +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances( + op_ptrs); + } + } + + // TODO: Add other types and layouts + +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; + +// GEMM + Add + FastGelu +// DeviceGemmMultipleD specialization template > op_ptrs; -#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) +#if defined(CK_USE_XDL) +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -143,7 +227,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif +#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8 #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) if constexpr(is_same_v && is_same_v && @@ -156,8 +240,9 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } -#endif +#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8 +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -186,6 +271,29 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddFastGelu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 45d6abce01..13878116c2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,9 +1,11 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_fastgelu_instance - device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp - device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + + device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp new file mode 100644 index 0000000000..52edd68752 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// e = elementwise((a * b), d0) +// elementwise(c, d0) = fastgelu(c + d0) +// output: e[m, n] +// input: a[m, k], b[n, k], d0[m, n] + +using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmDefault, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 4f4a1f5356..2929f5a042 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -45,7 +45,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) @@ -86,11 +85,14 @@ if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFIN list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) + endif() endif() if(DL_KERNELS) @@ -152,7 +154,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) @@ -202,6 +203,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) + endif() endif() if(DL_KERNELS) diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index ab4c781847..f7430b8ae1 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,19 +1,24 @@ -add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) + target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance) endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) + target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) + target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) + target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() + +add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp new file mode 100644 index 0000000000..4ac88770a1 --- /dev/null +++ b/test/gemm_add/test_gemm_add_fastgelu_wmma.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_fastgelu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddFastgelu : public TestGemmD0Common +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddFastgeluImpl = + ck::profiler::profile_gemm_add_fastgelu_impl; + + decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; } +}; + +using KernelTypes = ::testing::Types>; + +TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); +TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index 1b12ab7528..2c055a8006 100644 --- a/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddFastgelu : public TestGemmAdd +class TestGemmAddFastgelu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_relu_xdl.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp index e8b769b1cb..35aaba96b1 100644 --- a/test/gemm_add/test_gemm_add_relu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddRelu : public TestGemmAdd +class TestGemmAddRelu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_silu_xdl.cpp b/test/gemm_add/test_gemm_add_silu_xdl.cpp index 75fa59a8e7..8d242869c6 100644 --- a/test/gemm_add/test_gemm_add_silu_xdl.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -1,13 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add_xdl.hpp" +#include "test_gemm_common.hpp" template -class TestGemmAddSilu : public TestGemmAdd +class TestGemmAddSilu : public TestGemmD0Common { private: using ADataType = std::tuple_element_t<0, Tuple>; diff --git a/test/gemm_add/test_gemm_add_xdl.hpp b/test/gemm_add/test_gemm_add_xdl.hpp index 11d3d1c10a..3cc5405b5f 100644 --- a/test/gemm_add/test_gemm_add_xdl.hpp +++ b/test/gemm_add/test_gemm_add_xdl.hpp @@ -1,22 +1,15 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_impl.hpp" - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using I8 = int8_t; -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; +#include "test_gemm_common.hpp" template -class TestGemmAdd : public ::testing::Test +class TestGemmAdd : public TestGemmD0Common { - protected: + private: using ADataType = std::tuple_element_t<0, Tuple>; using BDataType = std::tuple_element_t<1, Tuple>; using AccDataType = std::tuple_element_t<2, Tuple>; @@ -37,32 +30,7 @@ class TestGemmAdd : public ::testing::Test D0Layout, ELayout>; - virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; } - - void Run() - { - std::vector> lengths = { - {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; - - bool all_success = true; - - for(auto length : lengths) - { - int M = length[0]; - int N = length[1]; - int K = length[2]; - int StrideA = ck::is_same_v ? K : M; - int StrideB = ck::is_same_v ? N : K; - int StrideD0 = ck::is_same_v ? N : M; - int StrideE = ck::is_same_v ? N : M; - - all_success = - all_success & - GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); - } - - EXPECT_TRUE(all_success); - } + decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } }; using KernelTypes = ::testing::Types, diff --git a/test/gemm_add/test_gemm_common.hpp b/test/gemm_add/test_gemm_common.hpp new file mode 100644 index 0000000000..1cf41d7538 --- /dev/null +++ b/test/gemm_add/test_gemm_common.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +class TestGemmD0Common : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + + virtual decltype(ProfileGemmAddImpl) GetImpl() = 0; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); + } + + EXPECT_TRUE(all_success); + } +};