diff --git a/CMakeLists.txt b/CMakeLists.txt index 86ed3c96ec..5f4f6a52fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,8 @@ option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) +option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) if(CK_EXPERIMENTAL_BUILDER) add_definitions(-DCK_EXPERIMENTAL_BUILDER) @@ -232,12 +234,12 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}" # Cache SUPPORTED_GPU_TARGETS for debug set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") -if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") @@ -250,7 +252,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10") add_definitions(-DCK_GFX1030_SUPPORT) endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") @@ -260,7 +262,7 @@ endif() # define the macro with the current value (0 or 1) add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA}) -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA FP8 gemms on native architectures") add_definitions(-DCK_USE_WMMA_FP8) set(CK_USE_WMMA_FP8 "ON") diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index b1e3d86971..ce41c3310f 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -37,6 +37,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() +add_custom_target(example_grouped_gemm_wmma) +add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16) + +add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp new file mode 100644 index 0000000000..e4da397c23 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git 
a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp new file mode 100644 index 0000000000..d5b2205892 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 0e64fbb7c6..764b533455 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -19,6 +19,10 @@ struct ProblemSize final std::vector stride_Cs; ck::index_t group_count; + +#if defined(EXAMPLE_USE_SPLITK) + ck::index_t k_batch; +#endif }; struct ExecutionConfig final @@ -177,6 +181,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); +#if defined(EXAMPLE_USE_SPLITK) + gemm.SetKBatchSize(&argument, problem_size.k_batch); +#endif + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument); @@ -285,12 +293,15 @@ bool run_grouped_gemm_example(int argc, char* argv[]) ExecutionConfig config; problem_size.group_count = 16; +#if defined(EXAMPLE_USE_SPLITK) + problem_size.k_batch = 1; +#endif if(argc == 1) { // use default cases } - else if(argc == 4 || argc == 6) + else if(argc == 4 || argc == 6 || argc == 7) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); @@ -300,6 +311,13 @@ bool run_grouped_gemm_example(int argc, char* argv[]) config.async_hargs = std::stoi(argv[4]); problem_size.group_count = std::stoi(argv[5]); } + +#if defined(EXAMPLE_USE_SPLITK) + if(argc == 7) + { + problem_size.k_batch = std::stoi(argv[6]); + } +#endif } else { @@ -307,7 +325,10 @@ bool run_grouped_gemm_example(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: async hargs (0=n0, 1=yes)\n"); - printf("arg5: group count (default=16)"); + printf("arg5: group count (default=16)\n"); +#if defined(EXAMPLE_USE_SPLITK) + printf("arg6: k-batch count (default=1)\n"); +#endif exit(1); } diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index a42f7170aa..ec623db6f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -199,7 +199,7 @@ struct BaseArgument BaseArgument(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default; - virtual ~BaseArgument() {} + virtual __host__ __device__ ~BaseArgument() {} void* p_workspace_ = nullptr; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp new file mode 100644 index 0000000000..2f0c047167 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -0,0 +1,827 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + // Binary search lookup to find which group this block is part of + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index + // and thus needs a non-const reference. 
It's also not feasible to store this in global + // memory as different threads would be writing different K values to the same arg struct + auto karg = gemm_desc_ptr[group_id].karg_; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; + + // Tile index first dimension is the K batch + auto tile_index = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + karg, + block_2_ctile_map, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) +} + +template +struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceBatchedGemm base class. + false>; // PermuteB not supported by DeviceBatchedGemm base class. + + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter. 
+ static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + template + struct GemmTransKernelArgBase + { + KernelArgument_ karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + using GemmTransKernelArg = GemmTransKernelArgBase; + + static constexpr index_t DefaultKBatch = 1; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + // TODO: Add A/B/CDE element op? + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + M, m_padded, N, n_padded, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument(std::array{p_As[i]}, + std::array{p_Bs[i]}, + std::array{}, // p_ds_grid_ + type_convert(p_Es[i]), + M, + N, + K, + std::array{stride_a}, + std::array{stride_b}, + std::array{}, // StrideDs_ + stride_c, + K_BATCH, + PassThrough{}, + PassThrough{}, + PassThrough{}, + false); + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. 
+ * + * @param[in] kbatch The new splitK parameter value. + */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, K_BATCH); + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, K_BATCH); + const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KRead = k_read; + karg.KPadded = k_padded; + karg.AK0 = ak0_padded; + karg.BK0 = bk0_padded; + karg.KBatch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + void* gemm_kernel_host_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_kernel_args_[0].karg_); + + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.KBatch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + cpy_stream)); + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_e_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size()); + }; + + // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.K_BATCH > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != 
arg.group_count_)
+        {
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                std::cout << "The group count is not equal to sum of skipped groups "
+                             "and kernel args size!"
+                          << std::endl;
+            }
+            return false;
+        }
+
+        bool supported = true;
+        for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
+        {
+            const auto& a         = arg.gemm_kernel_args_[i].karg_;
+            bool group_arg_valid  = GridwiseGemm::CheckValidity(a);
+
+            if(not group_arg_valid)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " has invalid GridwiseGemm settings!" << std::endl;
+                    a.Print();
+                }
+            }
+            supported = supported && group_arg_valid;
+        }
+        return supported;
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast(p_arg));
+    }
+
+    static auto MakeArgument(std::vector& p_As,
+                             std::vector& p_Bs,
+                             std::vector>&,
+                             std::vector& p_Es,
+                             std::vector gemm_descs,
+                             AElementwiseOperation,
+                             BElementwiseOperation,
+                             CDEElementwiseOperation)
+    {
+        return Argument{p_As, p_Bs, p_Es, gemm_descs};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr
+    MakeArgumentPointer(std::vector& p_As,
+                        std::vector& p_Bs,
+                        std::vector>&,
+                        std::vector& p_Es,
+                        std::vector& gemm_descs,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CDEElementwiseOperation) override
+    {
+        return std::make_unique(p_As, p_Bs, p_Es, gemm_descs);
+    }
+
+    // polymorphic
+    std::unique_ptr MakeInvokerPointer() override
+    {
+        return std::make_unique(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        std::map BlkGemmPipelineSchedulerToString{
+            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+        std::map BlkGemmPipelineVersionToString{
+            {BlockGemmPipelineVersion::v1, "v1"},
+            {BlockGemmPipelineVersion::v2, "v2"},
+            {BlockGemmPipelineVersion::v3, "v3"},
+            {BlockGemmPipelineVersion::v4, "v4"},
+            {BlockGemmPipelineVersion::v5, "v5"}};
+
+        // clang-format off
+        str << "DeviceGroupedGemm_WmmaSplitK"
+            << "<"
+            << std::string(ALayout::name)[0] << ","
+            << std::string(BLayout::name)[0] << ","
+            << std::string(ELayout::name)[0] << ","
+            << BlockSize << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << KPerBlock << ", "
+            << AK1 << ", "
+            << BK1 << ", "
+            << MPerWmma << ", "
+            << NPerWmma << ", "
+            << MRepeat << ", "
+            << NRepeat << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMRepeatPerShuffle << ", "
+            << CShuffleNRepeatPerShuffle << ", "
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
+            << ">";
+        // clang-format on
+
+        return str.str();
+    }
+
+    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+    {
+        auto p_arg_ = dynamic_cast(p_arg);
+        if(p_arg_)
+        {
+            return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
+        }
+        else
+            throw std::runtime_error("The argument pointer is not an object of "
+                                     "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!");
+    }
+
+    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
+    {
+        return GetWorkSpaceSize(p_arg);
+    }
+
+    size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
+
+    // TODO: deprecation notice.
+ static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_kernel_args_.begin(), + pArg_->gemm_kernel_args_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 2f3555e33e..6629be2511 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 DsGridPointer p_ds_grid; EDataType* p_e_grid; - const AElementwiseOperation a_element_op; - const BElementwiseOperation b_element_op; - const CDEElementwiseOperation cde_element_op; + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; @@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3 template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(AsGridPointer& p_as_grid, BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, void* p_shared, const Problem& problem, + const Block2CTileMap& block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, @@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3 return; } - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_m_id = + 
__builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); + const index_t block_n_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) using BScale = typename BlockwiseGemmPipe::Empty; @@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3 epilogue_args); } + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args) + { + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + problem, + DefaultBlock2CTileMap(problem), + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + // Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, + const Block2CTileMap& block_2_ctile_map, EpilogueArgument& epilogue_args) { // shift A matrices pointer for splitk @@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3 splitk_batch_offset.b_k_split_offset[i]; }); - Run( - p_as_grid_splitk, - p_bs_grid_splitk, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); + Run(p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg, + block_2_ctile_map, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void Run(void* p_shared, + const SplitKBatchOffset& splitk_batch_offset, + Argument& karg, + EpilogueArgument& epilogue_args) + { + Run( + p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args); + } + + __device__ static auto DefaultBlock2CTileMap(const Problem& problem) + { + return Block2CTileMap{problem.M, problem.N, 4}; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 80103cfdb3..4de3a35b3e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. 
AK1: " + << AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch + << ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index bf8536d268..4d9c09f597 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -15,6 +15,142 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); +#endif +#if defined(CK_ENABLE_BF16) +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( @@ -409,6 +545,81 @@ struct DeviceOperationInstanceFactory> op_ptrs; +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(op_ptrs); + } + } +#endif +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { 
+ add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp new file mode 100644 index 0000000000..6d5da9208b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AccDataType = F32; +using DsDataType = Empty_Tuple; + +using DsLayout = Empty_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; +static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave; +static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + 
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang`-format on + >; + +// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, 
T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 
1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported +// padding/scheduler/pipeline version combinations +template + typename LayoutInstances, + typename ADataType, // NOTE: type parameters as last so that they can be inferred from the + typename BDataType, // vector argument + typename EDataType> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + 
instances, LayoutInstances{}); +} + +// Helper function to add a list of layout instances for instances with matching A/B/E data types +// for all supported padding/scheduler/pipeline version combinations +template + typename LayoutInstances> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index efb6b2580a..7c64fa7850 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -57,7 +57,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build XDL instances if gfx9 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl") + if(((NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_XDL) AND source_name MATCHES "_xdl") message(DEBUG "removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -67,7 +67,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build WMMA instances if gfx11 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") + if(((NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_WMMA) AND source_name MATCHES "_wmma") message(DEBUG "removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -88,7 +88,7 @@ function(add_instance_library INSTANCE_NAME) endif() endif() # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ - if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") + if((NOT INST_TARGETS MATCHES "gfx12" OR FORCE_DISABLE_WMMA) AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -274,7 +274,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12" OR FORCE_DISABLE_XDL)) message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -282,7 +282,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only MX instances, but gfx950 is not on the targets list. 
Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (((NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) OR FORCE_DISABLE_WMMA)) message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND ((NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") OR (FORCE_DISABLE_XDL AND FORCE_DISABLE_WMMA))) message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() @@ -333,20 +333,22 @@ FOREACH(subdir_path ${dir_list}) if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) - if("${cmake_instance}" MATCHES "gemm") - list(APPEND CK_DEVICE_GEMM_INSTANCES $) - elseif("${cmake_instance}" MATCHES "conv") - list(APPEND CK_DEVICE_CONV_INSTANCES $) - elseif("${cmake_instance}" MATCHES "mha") - list(APPEND CK_DEVICE_MHA_INSTANCES $) - elseif("${cmake_instance}" MATCHES "contr") - list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) - elseif("${cmake_instance}" MATCHES "reduce") - list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) - else() - list(APPEND CK_DEVICE_OTHER_INSTANCES $) - endif() - message(DEBUG "add_instance_directory ${subdir_path}") + if (TARGET device_${target_dir}_instance) + if("${cmake_instance}" MATCHES "gemm") + list(APPEND CK_DEVICE_GEMM_INSTANCES $) + elseif("${cmake_instance}" MATCHES "conv") + list(APPEND CK_DEVICE_CONV_INSTANCES $) + elseif("${cmake_instance}" MATCHES "mha") + list(APPEND CK_DEVICE_MHA_INSTANCES $) + elseif("${cmake_instance}" MATCHES "contr") + list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) + elseif("${cmake_instance}" MATCHES "reduce") + list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) + else() + list(APPEND CK_DEVICE_OTHER_INSTANCES $) + endif() + message(DEBUG "add_instance_directory ${subdir_path}") + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index b2e7e51a3c..ba54c6ffb3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -36,4 +36,17 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..6f8b31e663 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..2839890dcf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
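Once these sources are part of device_grouped_gemm_instance, the per-layout add_device_grouped_gemm_wmma_universal_* functions below feed the usual CK instance registry, so clients keep discovering instances through the factory. A minimal consumption sketch, assuming the standard DeviceOperationInstanceFactory pattern and the DeviceGroupedGemm base operation (the exact base template the new instances register under is not visible in this diff):

    // Sketch: enumerate grouped-GEMM instances for bf16 with row-major A and
    // column-major B. On gfx11/gfx12 builds the returned list should now also
    // contain the WMMA universal instances added here.
    #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using BF16        = ck::bhalf_t;

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<
        Row, Col, ck::Tuple<>, Row, BF16, BF16, ck::Tuple<>, BF16,
        PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();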
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..c41dbdfc7b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..55d1163900 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..ea7eb0d615 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..816188c7ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..6680002d47 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3e82899834 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e93e9dff4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F16; +using BDataType = F8; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, 
ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e8f043d1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F8; +using BDataType = F16; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, 
DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 03a2ed3186..0ee0ee4c2e 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -42,10 +42,11 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - const std::vector& kbatches = {}, - int n_warmup = 1, - int n_iter = 10, - int instance_index = -1) + const std::vector& kbatches = {}, + int n_warmup = 1, + int n_iter = 10, + int instance_index = -1, + bool fail_if_no_supported_instance = false) { bool pass = true; // TODO: Fixme - we do not pass compute data type here but need it @@ -225,6 +226,7 @@ bool profile_grouped_gemm_impl(int do_verification, } } // profile device GEMM instances + int instances_supporting_all_batch_sizes = 0; for(auto& gemm_ptr : op_ptrs) { auto argument_ptr = @@ -268,6 +270,7 @@ bool profile_grouped_gemm_impl(int do_verification, kbatch_list = kbatches; } + bool all_batch_sizes_supported = true; for(std::size_t j = 0; j < kbatch_list.size(); j++) { auto kbatch_curr = kbatch_list[j]; @@ -367,10 +370,30 @@ bool profile_grouped_gemm_impl(int do_verification, } else { + all_batch_sizes_supported = false; std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" << std::endl; } } + + // If all batch sizes were supported by this instance, the instance can be marked as + // 'supported' for this problem + if(all_batch_sizes_supported) + { + ++instances_supporting_all_batch_sizes; + } + } + + // Warn if not a single instance was supported + if(instances_supporting_all_batch_sizes == 0) + { + std::cout << "Warning! No instance found that supported all of the batch sizes." 
+ << std::endl; + + if(fail_if_no_supported_instance) + { + return false; + } } if(time_kernel) @@ -384,6 +407,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed" << std::endl; } + return pass; } diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f78ce29daa..c6b5180013 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -3,10 +3,15 @@ add_custom_target(test_grouped_gemm) -add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp index 1683e16323..56fb758f89 100644 --- a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp @@ -9,6 +9,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "test_grouped_gemm_util.hpp" +#include "test_grouped_gemm_interface_xdl.hpp" class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test { diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp new file mode 100644 index 0000000000..a04d13c1ea --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
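For reference, the profiler entry point above gains one trailing parameter with a default that preserves existing behaviour; only callers that opt in (such as the WMMA test path below) turn a missing instance into a failure. A hypothetical call site (template arguments and problem vectors are placeholders; the positional arguments follow the updated signature):

    // Fail, rather than just warn, when no instance supports every requested k-batch.
    bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, BDataType, EDataType,
                                                        ALayout, BLayout, ELayout>(
        /*do_verification=*/1, /*init_method=*/1, /*do_log=*/false, /*time_kernel=*/false,
        Ms, Ns, Ks, StrideAs, StrideBs, StrideCs,
        /*kbatches=*/{1, 2, 4},
        /*n_warmup=*/1, /*n_iter=*/10, /*instance_index=*/-1,
        /*fail_if_no_supported_instance=*/true);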
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/number.hpp" +#include "profiler/profile_grouped_gemm_impl.hpp" + +namespace ck { +namespace test { + +template +struct DeviceGroupedGemmSplitkInstanceWrapper +{ + using F16 = half_t; + using F32 = float; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + + using EmptyTuple = ck::Tuple<>; + + template + using S = ck::Sequence; + + template + using I = ck::Number; + + using ABlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcAccessOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; + using ABlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<8>, I<2>>; + using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; + + using BBlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcAccessOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; + using BBlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<2>, I<8>>; + using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; + + using DeviceGroupedGemmSplitKInstance = + tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< + ALayout, + BLayout, + EmptyTuple, + ELayout, + F16, + F16, + F32, + F16, + EmptyTuple, + F16, + PassThrough, + PassThrough, + PassThrough, + GemmSpec, + 1, + 128, + 128, + 128, + KPerBlock, + K1, + K1, + 16, + 16, + 8, + 4, + S<1, 4, 16, 1>, + ABlockTransferThreadClusterArrageOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim::value, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1::value, + ABlockLdsAddExtraM::value, + S<1, 4, 16, 1>, + BBlockTransferThreadClusterArrageOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim::value, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1::value, + BBlockLdsAddExtraM::value, + 1, + 1, + S<1, 16, 1, 8>, + CDEBlockTransferScalarPerVector_NPerBlock>; + + bool IsSupported(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector 
p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + + return ggemm_instance.IsSupportedArgument(argument); + } + + float Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + if(kbatch > 1 && ck::is_gfx11_supported()) + { + EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); + return 0; + } + else + { + EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); + auto invoker = ggemm_instance.MakeInvoker(); + DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); + ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); + return invoker.Run(argument, StreamConfig{nullptr, false}); + } + } +}; + +} // namespace test +} // namespace ck diff --git a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_splitk.cpp similarity index 62% rename from test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk.cpp index c237fd562e..968bea2109 100644 --- a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_splitk.cpp @@ -24,21 +24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; template class TestGroupedGemm : public ck::test::TestGroupedGemm { + public: + void SetUp() override + { + ck::test::TestGroupedGemm::SetUp(); + +#if defined(CK_USE_WMMA) + // The old XDL tests didn't fail if instances were not supported, so we want to keep that + // behaviour. When compiling WMMA instances and WMMA is supported, we'll fail if a + // specific case is not supported. + this->fail_if_no_supported_instances_ = + ck::is_gfx11_supported() || ck::is_gfx12_supported(); +#endif + } }; // clang-format off using KernelTypes = ::testing::Types< + +#if defined(CK_USE_WMMA) + // WMMA only. No reason not to have it for XDL, but the instance was not defined and it was not in the original test. 
+ std::tuple< Col, Col, Row, BF16, BF16, BF16>, +#endif + +#if defined(CK_USE_XDL) && defined(__gfx9__) + // XDL only at the moment, instances for WMMA not defined + std::tuple< Row, Row, Row, BF16, I8, BF16>, + std::tuple< Row, Col, Row, BF16, I8, BF16>, +#endif + +#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) + std::tuple< Row, Row, Row, F8, F16, F16>, + std::tuple< Row, Row, Row, F16, F8, F16>, +#endif + std::tuple< Row, Row, Row, F16, F16, F16>, std::tuple< Row, Col, Row, F16, F16, F16>, std::tuple< Col, Row, Row, F16, F16, F16>, std::tuple< Col, Col, Row, F16, F16, F16>, + std::tuple< Row, Row, Row, BF16, BF16, BF16>, std::tuple< Row, Col, Row, BF16, BF16, BF16>, - std::tuple< Col, Row, Row, BF16, BF16, BF16>, - std::tuple< Row, Row, Row, BF16, I8, BF16>, - std::tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple< Row, Row, Row, F16, F8, F16>, - std::tuple< Row, Row, Row, F8, F16, F16> + std::tuple< Col, Row, Row, BF16, BF16, BF16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc index 16c4ad5909..84558c89f9 100644 --- a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded) TYPED_TEST(TestGroupedGemm, TestLargeKBatch) { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still run the tests for fp32, but we currently don't have instances for + // it so we disable it entirely + if(ck::is_gfx11_supported()) + GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add " + "instructions"; + const std::vector Ms{188, 210}; constexpr int N = 768; constexpr int K = 4096; diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 912066ee80..6ee6465cc4 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -11,16 +11,7 @@ #include #include "ck/ck.hpp" -#include "ck/stream_config.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" extern ck::index_t param_mask; @@ -41,7 +32,7 @@ std::string serialize_range(const Range& range) return std::string(str.begin(), str.end() - 2); } -template +template class TestGroupedGemm : public testing::Test { protected: @@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test static constexpr bool bench_ = false; // measure kernel performance static constexpr int n_warmup_ = 0; static constexpr int n_iter_ = 1; + + bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances; std::vector k_batches_; - void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + void SetUp() override + { + constexpr bool require_16bit_atomic_add = + std::is_same_v || std::is_same_v; + if(require_16bit_atomic_add && ck::is_gfx11_supported()) + { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still 
use split-K for fp32, but we currently don't have + // instances for it so we disable it entirely + k_batches_ = {1}; + } + else + { + k_batches_ = {1, 2, 3, 5, 8}; + } + } private: template @@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = ck::profiler::profile_grouped_gemm_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index); + bool pass = + ck::profiler::profile_grouped_gemm_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); EXPECT_TRUE(pass); } }; -template -struct DeviceGroupedGemmSplitkInstanceWrapper -{ - using F16 = half_t; - using F32 = float; - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - using PassThrough = tensor_operation::element_wise::PassThrough; - - using EmptyTuple = ck::Tuple<>; - - template - using S = ck::Sequence; - - template - using I = ck::Number; - - using ABlockTransferThreadClusterArrageOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcAccessOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; - using ABlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<8>, I<2>>; - using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; - - using BBlockTransferThreadClusterArrageOrder = - std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcAccessOrder = - std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; - using BBlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<2>, I<8>>; - using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; - - using DeviceGroupedGemmSplitKInstance = - tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< - ALayout, - BLayout, - EmptyTuple, - ELayout, - F16, - F16, - F32, - F16, - EmptyTuple, - F16, - PassThrough, - PassThrough, - PassThrough, - GemmSpec, - 1, - 128, - 128, - 128, - KPerBlock, - K1, - K1, - 16, - 16, - 8, - 4, - S<1, 4, 16, 1>, - ABlockTransferThreadClusterArrageOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim::value, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1::value, - ABlockLdsAddExtraM::value, - S<1, 4, 16, 1>, - BBlockTransferThreadClusterArrageOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim::value, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1::value, - BBlockLdsAddExtraM::value, - 1, - 1, - S<1, 16, 1, 8>, - CDEBlockTransferScalarPerVector_NPerBlock>; - - bool IsSupported(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - 
std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - - return ggemm_instance.IsSupportedArgument(argument); - } - - float Run(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - if(kbatch > 1 && ck::is_gfx11_supported()) - { - EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); - return 0; - } - else - { - EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); - auto invoker = ggemm_instance.MakeInvoker(); - DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); - ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); - return invoker.Run(argument, StreamConfig{nullptr, false}); - } - } -}; - } // namespace test } // namespace ck
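End to end, the new coverage can be exercised through the targets this patch adds. A hypothetical build-and-run session on a gfx12 machine (the binary output path and the example's positional argument convention of verification / initialization / timing are assumptions, since run_grouped_gemm_example.inc is not shown here):

    cmake -S . -B build -D GPU_TARGETS=gfx1201
    cmake --build build --target example_grouped_gemm_wmma_splitk_fp16 test_grouped_gemm_splitk
    ./build/bin/example_grouped_gemm_wmma_splitk_fp16 1 2 1   # verify, init method, time kernel
    ./build/bin/test_grouped_gemm_splitk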