From a73a06fb1d3a8e98e12f11265bcac54ac04b7a71 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Wed, 7 Jan 2026 19:21:57 +0000 Subject: [PATCH] Merge commit 'aad4cf098511b3f58c5bd3c32e4534d438f7539c' into develop --- ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 682 ++++++++++++++++++ .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 8 +- ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 95 ++- .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 166 ++++- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 2 + .../device_gemm_mean_squaremean_instance.hpp | 41 ++ ...vice_grouped_gemm_wmma_splitk_instance.hpp | 99 ++- .../gpu/grouped_gemm_fastgelu.hpp | 82 +++ .../gpu/gemm_bias_add_reduce/CMakeLists.txt | 7 +- ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 85 +++ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 84 +++ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 84 +++ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 81 +++ ...ersal_bf16_bf16_bf16_km_kn_mn_instance.cpp | 6 +- ...ersal_bf16_bf16_bf16_km_nk_mn_instance.cpp | 6 +- ...ersal_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 6 +- ...ersal_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 6 +- ...niversal_f16_f16_f16_km_kn_mn_instance.cpp | 6 +- ...niversal_f16_f16_f16_km_nk_mn_instance.cpp | 6 +- ...niversal_f16_f16_f16_mk_kn_mn_instance.cpp | 6 +- ...niversal_f16_f16_f16_mk_nk_mn_instance.cpp | 6 +- ...universal_f16_f8_f16_mk_kn_mn_instance.cpp | 11 +- ...universal_f8_f16_f16_mk_kn_mn_instance.cpp | 11 +- .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 7 +- ...elu_wmma_f16_f16_f16_km_kn_mn_instance.cpp | 37 + ...elu_wmma_f16_f16_f16_km_nk_mn_instance.cpp | 37 + ...elu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp | 38 + ...elu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp | 38 + .../profile_gemm_bias_add_reduce_impl.hpp | 148 ++-- .../profile_grouped_gemm_fastgelu_impl.hpp | 254 +------ .../profiler/profile_grouped_gemm_impl.hpp | 130 ++-- test/CMakeLists.txt | 1 + test/gemm_bias_add_reduce/CMakeLists.txt | 9 + .../test_gemm_bias_add_reduce_fp16.cpp | 106 +++ .../gemm_bias_add_reduce/test_gemm_common.hpp | 61 ++ test/grouped_gemm/CMakeLists.txt | 6 + .../test_grouped_gemm_fastgelu.cpp | 62 ++ .../test_grouped_gemm_ut_cases.inc | 11 +- test/grouped_gemm/test_grouped_gemm_util.hpp | 98 ++- 39 files changed, 2089 insertions(+), 540 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 test/gemm_bias_add_reduce/CMakeLists.txt create mode 100644 test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp create mode 100644 test/gemm_bias_add_reduce/test_gemm_common.hpp create mode 100644 test/grouped_gemm/test_grouped_gemm_fastgelu.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..c64a1d504d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,682 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_bias_add_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops, + const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = EpilogueType( + p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = d0_element_op; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 + : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using CDEShuffleBlockTransferScalarPerVectors = Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple, + ELayout, + Tuple, + Tuple, + AccDataType, + 
CShuffleDataType, + Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using ReduceTrait = ReduceTrait_; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + StrideC1_{StrideC1}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + index_t StrideC1_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + 
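// How the packing above maps the auxiliary tensors (a reading of this Argument
// construction, for clarity): the bias and the D0 tensor travel as the two
// entries of the Ds tuple,
//   p_ds[0] = bias, StrideDs[0] = 0        // one N-length bias row broadcast over all M rows
//   p_ds[1] = d0,   StrideDs[1] = StrideC1 // a full MxN residual tensor
// kbatch is fixed at 1; the reduction epilogue accumulates into
// p_reduces_grid_ and appears not to handle split-K partial results.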
if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.d0_element_op_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Even) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + else if(TailNum == TailNumber::Odd) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array 
p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 0240fcb619..b64b72f4d4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -49,8 +49,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = - EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M); + auto epilogue_args = EpilogueType(p_reduces_grid, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -188,6 +191,7 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera using ReduceTrait = ReduceTrait_ #include "ck/ck.hpp" +#include 
"ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -242,7 +243,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; template struct GemmTransKernelArgBase { @@ -274,23 +274,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK& p_As, std::vector& p_Bs, + std::vector>& p_Ds, std::vector& p_Es, - std::vector& gemm_descs) - : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) { // TODO: use occupancy api to calculate appropriate batch size. } Argument(std::vector& p_As, std::vector& p_Bs, + std::vector>& p_Ds, std::vector& p_Es, std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, index_t kbatch) : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} { @@ -299,9 +314,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(p_As.size()) && group_count_ == ck::type_convert(p_Bs.size()) && + ((NumDTensor == 0 && p_Ds.size() == 0) || + group_count_ == ck::type_convert(p_Ds.size())) && group_count_ == ck::type_convert(p_Es.size()))) { - throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); } gemm_kernel_args_.reserve(group_count_); @@ -320,9 +337,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(stride_d_vec.size()))) + { + throw std::runtime_error("wrong! stride D mismatch"); + } + + // Copy D stride vector to fixed-size array + std::array stride_ds; + if constexpr(NumDTensor > 0) + { + std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds); + } const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t n_padded = GridwiseGemm::CalculateNPadded(N); @@ -346,19 +376,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK{p_As[i]}, std::array{p_Bs[i]}, - std::array{}, // p_ds_grid_ + p_Ds[i], type_convert(p_Es[i]), M, N, K, std::array{stride_a}, std::array{stride_b}, - std::array{}, // StrideDs_ + stride_ds, stride_c, K_BATCH, - PassThrough{}, - PassThrough{}, - PassThrough{}, + a_element_op, + b_element_op, + c_element_op, false); gemm_kernel_args_.emplace_back( @@ -632,6 +662,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK) + { + if(arg.K_BATCH > 1) + { + // Using SplitK and a C element op would require a two stage kernel where the second + // stage applies the op on the accumulated results + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "C element operators are not supported when using SplitK. Set " + "K_BATCH to 1 or remove the operator." 
+ << std::endl; + } + return false; + } + } + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { @@ -681,14 +728,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK& p_As, std::vector& p_Bs, - std::vector>&, + std::vector>& p_Ds, std::vector& p_Es, std::vector gemm_descs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation) + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) { - return Argument{p_As, p_Bs, p_Es, gemm_descs}; + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -697,14 +745,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK MakeArgumentPointer(std::vector& p_As, std::vector& p_Bs, - std::vector>&, + std::vector>& p_Ds, std::vector& p_Es, std::vector& gemm_descs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation) override + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override { - return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index 942d4351b3..d1e7f35607 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -10,6 +10,7 @@ namespace ck { template const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // Thread transfer LDS to Vmem - auto cde_shuffle_block_copy_lds_and_global = - Base::template GetLDSToVmemEpilogueDescriptor( - c_ds_desc_refs, - e_grid_desc_mblock_mperblock_nblock_nperblock, - cde_element_op, - block_m_id, - block_n_id); - - // tuple of reference to C/Ds tensor buffers - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, @@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle }, Number{}); + // multiple Ds + constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple( + [&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; }, + Number{}); + + constexpr auto ds_thread_buf_size = + d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + auto c01_thread_buf = + make_static_buffer( + Number{}); + + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + return ThreadwiseTensorSliceTransfer_v2< + remove_cvref_t>, + typename ReduceTrait::ReduceAccDataType_, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + remove_cvref_t< + decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>, + Sequence, + Sequence<0, 
1, 2, 3>, + 3, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + }, + Number{}); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // Write E from Vgpr to Vmem + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + EDataType, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + EGlobalMemoryDataOperation, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op}; + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); @@ -365,15 +415,6 @@ struct EpilogueReduceCShuffle // make sure it's safe to read from LDS block_sync_lds(); - - // each block loads its C data from LDS, D from global, applies elementwise - // operation and stores result E to global - cde_shuffle_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - { c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_shuffle_block_buf, @@ -381,6 +422,53 @@ struct EpilogueReduceCShuffle make_tuple(I0, I0), c_reduce_thread_buf); + // Note: currently multiple Ds supports only Bias + Add. 
+ // It needs to be generalized for other operations (currently not needed) + if constexpr(NumDTensor > 0) + { + auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0); + // d0 / d1 operations + d0_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], + ds_grid_buf[I0], + ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0], + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + typename ReduceTrait::ReduceAccDataType_ out; + cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1); + + d1_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], + ds_grid_buf[I1], + ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1], + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_function(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + d0_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + } + + // Write E + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + // Reduction static_for<0, NumReduce, 1>{}([&](auto In) { auto& p_reduce_grid = p_reduces_grid[In]; @@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle { constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_global_step); + static_for<0, NumDTensor, 1>{}([&](auto I) { + auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step); }); // move on E - cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step); } }); } @@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops; typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops; index_t MRaw; + typename ReduceTrait::D0ElementwiseOperation_ d0_element_op; ReduceGridDesc_M reduce_grid_desc_m; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 64f50d13df..c168ca9d18 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { + block_sync_lds(); + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp index 6d23cd8745..c448a51cfc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -19,6 +19,7 @@ namespace instance { using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>; +#if defined(CK_USE_XDL) void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( @@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f std::vector&); void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); +#endif // CK_USE_WMMA template ::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + op_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#if defined(CK_USE_XDL) ck::tensor_operation::device::instance:: add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( op_ptrs); +#endif +#if defined(CK_USE_WMMA) + ck::tensor_operation::device::instance:: + add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + op_ptrs); +#endif } } diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp index d0de1c859b..27420da45e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -31,6 +31,7 @@ using S = ck::Sequence; using Empty_Tuple = ck::Tuple<>; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AccDataType = F32; using DsDataType = Empty_Tuple; @@ -38,10 +39,6 @@ using DsDataType = Empty_Tuple; using DsLayout = Empty_Tuple; using ELayout = Row; -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = PassThrough; - static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; @@ -54,6 +51,9 @@ template = false> using device_grouped_gemm_wmma_universal_km_kn_mn_instances = std::tuple< @@ -73,6 +73,9 @@ template = false> using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< // clang-format off @@ -91,6 +94,9 @@ template = false> using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = std::tuple< @@ -110,6 +116,9 @@ template = false> using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = std::tuple< @@ -124,17 +133,38 @@ using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = // clang-format on >; +// List of instance variants to add (pipeline/scheduler/padding combinations) +// Some are disabled now, can be re-enabled if needed +using InstanceVariant = + ck::Tuple; +static constexpr InstanceVariant InstanceVariants[] = { + + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1), + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), + + make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3), +}; + // Helper function to add a list of layout instances with specific A/B/E datatypes for all supported // padding/scheduler/pipeline version combinations template + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> typename LayoutInstances, typename ADataType, // NOTE: type parameters as last so that they can be inferred from the typename BDataType, // vector argument - typename EDataType> + typename EDataType, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> void add_device_grouped_gemm_wmma_universal_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - LayoutInstances{}); - add_device_operation_instances(instances, - LayoutInstances{}); - add_device_operation_instances(instances, - LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); + // Add all instances from our instance list + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = 
InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); } // Helper function to add a list of layout instances for instances with matching A/B/E data types @@ -170,8 +199,14 @@ template - typename LayoutInstances> + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> + typename LayoutInstances, + typename AElementOp, // NOTE: element-wise op parameters as last so that they can be + typename BElementOp, // inferred from the vector argument + typename CDEElementOp> void add_device_grouped_gemm_wmma_universal_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); - add_device_operation_instances( - instances, LayoutInstances{}); + // Add all instances from our instance list + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); } } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp index cce97d0933..3f722cc688 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp @@ -15,6 +15,64 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + +#if defined(CK_USE_XDL) +#if defined(CK_ENABLE_FP16) void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL // GroupedGEMM + GELU template > op_ptrs; +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) 
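// As in the branches above, each layout branch appends the XDL and/or WMMA
// instance families that were compiled in; a caller would then probe every
// returned op before use, e.g. (sketch, `argument` being a caller-built
// argument pointer):
//   for(auto& op : op_ptrs)
//       if(op->IsSupportedArgument(argument.get())) { /* run or benchmark */ }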
add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); +#endif +#if defined(CK_USE_WMMA) + add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); +#endif } } +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt index a82e95d8d1..8be1dc6b45 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,10 +1,15 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp + + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c736fae147 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 
2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, 
ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..a702503e7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer|
CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 8, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, 
BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 
128, 64, 32, 2, 8, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 8, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e27cb9d630 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
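Each of these new instance translation units compiles one A/B layout combination of the DeviceGemmBiasAddReduce_Wmma_CShuffleV3 tuple and registers it through add_device_operation_instances. A minimal consumer sketch, modeled on the profiler refactor later in this patch (get_device_gemm_add_add_mean_squaremean_instances, MakeArgumentPointer, IsSupportedArgument, Run); the buffers, problem sizes, and element-op objects here are illustrative placeholders, not names introduced by this patch:

    // Hypothetical usage sketch for the f16 Row x Row -> Row (mk_kn) case.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        get_device_gemm_add_add_mean_squaremean_instances<F16, F16, F16, Row, Row, Row>();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(p_a, p_b, p_bias, {p_d0}, p_c, p_reduces,
                                                        M, N, K, StrideA, StrideB, StrideC,
                                                        {StrideD0}, gemm_element_ops,
                                                        {&d0_element_op}, reduce_in_element_ops,
                                                        reduce_out_element_ops);

        // Each instance only supports problems matching its tile shape and vector widths.
        if(!op_ptr->IsSupportedArgument(argument_ptr.get()))
            continue;

        auto invoker_ptr = op_ptr->MakeInvokerPointer();
        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, /*time_kernel=*/false});
    }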
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, 
ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 
1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, 
ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..a2d0e0ba9c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| 
CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, 
+ DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8, S<32, 2>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp index 6f8b31e663..9a46330ca8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( BF16, DsDataType, BF16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp index 2839890dcf..3af284f088 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( BF16, DsDataType, BF16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp index c41dbdfc7b..f5151d8682 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( BF16, DsDataType, BF16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 55d1163900..7183815210 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( BF16, DsDataType, BF16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp index ea7eb0d615..ff091a8a1b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( F16, DsDataType, F16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp index 816188c7ff..58beafc20e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( F16, DsDataType, F16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp index 6680002d47..6b918d5543 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( F16, DsDataType, F16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp index 3e82899834..fa2ef4daa3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp @@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( F16, DsDataType, F16, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp index e93e9dff4a..acd30b6e4b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp @@ -17,7 +17,10 @@ using EDataType = F16; template + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances = std::tuple< // clang-format off @@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( BDataType, DsDataType, EDataType, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp index e8f043d1f8..ee2691ce40 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp @@ -17,7 +17,10 @@ using EDataType = F16; template + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off @@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( BDataType, DsDataType, EDataType, - AElementOp, - BElementOp, - CDEElementOp>>>& instances) + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_universal_instances< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 1997427462..dc09107fb9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,10 +1,15 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp + + device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..275003a7d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..3bead2e154 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
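The four new WMMA translation units reuse the grouped-GEMM split-K "universal" instance lists, instantiated with FastGelu as the CDE elementwise operation; the file suffix encodes the A/B layouts (mk_kn = Row/Row, mk_nk = Row/Col, km_kn = Col/Row, km_nk = Col/Col, all producing a row-major E). A minimal lookup sketch reconstructed from the factory call in the profiler code this patch removes; treat the exact DeviceGroupedGemm parameter order as an assumption matching that code:

    // Hypothetical instance lookup for f16 Row x Row grouped GEMM + FastGelu.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<
        Row, Row, ck::Tuple<>, Row, // ALayout, BLayout, DsLayout, ELayout
        F16, F16, ck::Tuple<>, F16, // ADataType, BDataType, DsDataType, EDataType
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::FastGelu>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // Each returned pointer owns one tile-shape/pipeline configuration registered above.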
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..708ffed9de --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..dcaf830860 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index 1930cf9eb6..8561095f8d 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -9,6 +9,8 @@ #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/device_memory.hpp" @@ -17,40 +19,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F32 = float; -using F16 = ck::half_t; -using ReducePtrsGlobal = ck::Tuple; -using Div = ck::tensor_operation::element_wise::UnaryDivide; -using Identity = ck::tensor_operation::element_wise::PassThrough; -using Square = ck::tensor_operation::element_wise::UnarySquare; -using ReduceInElementOps = ck::Tuple; -using ReduceOutElementOps = ck::Tuple; - -using DeviceGemmBiasAddReduceNoOpPtr = - ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>; - -void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( - std::vector&); - -void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( - std::vector&); - -void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( - std::vector&); - -void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( - std::vector&); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - namespace ck { namespace profiler { @@ -63,7 +31,7 @@ template -void profile_gemm_bias_add_reduce_impl(int do_verification, +bool profile_gemm_bias_add_reduce_impl(int do_verification, int init_method, bool do_log, bool time_kernel, @@ -75,6 +43,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, int StrideC, int StrideD0) { + bool pass = true; + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { return HostTensorDescriptor({len}, {stride}); }; @@ -231,47 +201,19 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, bias_device_buf.ToDevice(bias_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); - // add device GEMM instances - std::vector gemm_ptrs; + // get device op instances + const auto op_ptrs = + 
ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances< + ADataType, + BDataType, + CDataType, + ALayout, + BLayout, + CLayout>(); - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::instance:: - add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::instance:: - add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::instance:: - add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( - gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::instance:: - add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( - gemm_ptrs); - } - } + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - if(gemm_ptrs.size() <= 0) + if(op_ptrs.size() <= 0) { throw std::runtime_error("wrong! no device GEMM instance found"); } @@ -282,29 +224,29 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, float best_gb_per_sec = 0; // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) + for(auto& op_ptr : op_ptrs) { - auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - bias_device_buf.GetDeviceBuffer(), - {d0_device_buf.GetDeviceBuffer()}, - c_device_buf.GetDeviceBuffer(), - p_reduces, - M, - N, - K, - StrideA, - StrideB, - StrideC, - {StrideD0}, - gemm_element_ops, - {&d0_element_op}, - reduce_in_element_ops, - reduce_out_element_ops); + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + bias_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + p_reduces, + M, + N, + K, + StrideA, + StrideB, + StrideC, + {StrideD0}, + gemm_element_ops, + {&d0_element_op}, + reduce_in_element_ops, + reduce_out_element_ops); - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // init DO, D1 to 0 reduce0_device_buf.SetZero(); @@ -313,12 +255,12 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - std::string gemm_name = gemm_ptr->GetTypeString(); + std::string gemm_name = op_ptr->GetTypeString(); std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + + sizeof(CDataType) * M * N + sizeof(BiasDataType) * N + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + sizeof(ReduceDataType) * M; @@ -343,9 +285,13 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); - 
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); - ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); - ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + if(!pass) + { + std::cout << op_ptr->GetTypeString() << " failed" << std::endl; + } if(do_log) { @@ -372,12 +318,14 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << op_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; } } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + return pass; } } // namespace profiler diff --git a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp index 227b494266..635af57717 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp @@ -17,6 +17,8 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "profile_grouped_gemm_impl.hpp" + namespace ck { namespace profiler { @@ -38,242 +40,30 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification, const std::vector& StrideBs, const std::vector& StrideCs) { - - bool pass = true; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - std::size_t group_count = Ms.size(); - - if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && - group_count == StrideBs.size() && group_count == StrideCs.size())) - { - throw std::runtime_error("wrong! 
diff --git a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
index 227b494266..635af57717 100644
--- a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp
@@ -17,6 +17,8 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+#include "profile_grouped_gemm_impl.hpp"
+
 namespace ck {
 namespace profiler {

@@ -38,242 +40,30 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
                                         const std::vector<int>& StrideBs,
                                         const std::vector<int>& StrideCs)
 {
-
-    bool pass = true;
-
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            using namespace ck::literals;
-
-            if(is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor({row, col}, {stride, 1_uz});
-            }
-            else
-            {
-                return HostTensorDescriptor({row, col}, {1_uz, stride});
-            }
-        };
-
-    std::size_t group_count = Ms.size();
-
-    if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
-         group_count == StrideBs.size() && group_count == StrideCs.size()))
-    {
-        throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
-    }
-
-    std::vector<Tensor<ADataType>> a_m_k;
-    std::vector<Tensor<BDataType>> b_k_n;
-    std::vector<Tensor<CDataType>> c_m_n_device_results;
-
-    for(std::size_t i = 0; i < group_count; i++)
-    {
-        a_m_k.push_back(
-            Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
-        b_k_n.push_back(
-            Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
-
-        c_m_n_device_results.push_back(
-            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
-
-        std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
-                  << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
-                  << "]:" << c_m_n_device_results[i].mDesc << std::endl;
-
-        switch(init_method)
-        {
-        case 0: break;
-        case 1:
-            ck::utils::FillUniformDistributionIntegerValue<ADataType>{}(a_m_k[i]);
-            ck::utils::FillUniformDistributionIntegerValue<BDataType>{}(b_k_n[i]);
-            break;
-        default:
-            ck::utils::FillUniformDistribution<ADataType>{0.0, 1.0}(a_m_k[i]);
-            ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
-        }
-
-        ck::utils::FillConstant<CDataType>{}(c_m_n_device_results[i]);
-    }
-
     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
     using CElementOp = ck::tensor_operation::element_wise::FastGelu;

-    const auto a_element_op = AElementOp{};
-    const auto b_element_op = BElementOp{};
-    const auto c_element_op = CElementOp{};
-
-    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
-    std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
-
-    a_device_buf.reserve(group_count);
-    b_device_buf.reserve(group_count);
-    c_device_buf.reserve(group_count);
-
-    std::vector<const void*> p_a, p_b;
-    std::vector<void*> p_c;
-
-    p_a.reserve(group_count);
-    p_b.reserve(group_count);
-    p_c.reserve(group_count);
-
-    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
-
-    gemm_descs.reserve(group_count);
-
-    for(std::size_t i = 0; i < group_count; i++)
-    {
-        a_device_buf.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
-        b_device_buf.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
-        c_device_buf.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
-
-        a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
-        b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
-        c_device_buf[i]->SetZero();
-
-        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
-
-        p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
-        p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
-        p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
-    }
-
-    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
-                                                                     BLayout,
-                                                                     ck::Tuple<>,
-                                                                     CLayout,
-                                                                     ADataType,
-                                                                     BDataType,
-                                                                     ck::Tuple<>,
-                                                                     CDataType,
-                                                                     AElementOp,
-                                                                     BElementOp,
-                                                                     CElementOp>;
-
-    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
-        DeviceOp>::GetInstances();
-
-    if(op_ptrs.size() <= 0)
-    {
-        throw std::runtime_error("wrong! no device GEMM instance found");
-    }
-
-    std::string best_gemm_name;
-    float best_ave_time   = 0;
-    float best_tflops     = 0;
-    float best_gb_per_sec = 0;
-
-    auto p_ds = std::vector<std::array<const void*, 0>>{};
-
-    // profile device GEMM instances
-    for(auto& gemm_ptr : op_ptrs)
-    {
-        auto argument_ptr = gemm_ptr->MakeArgumentPointer(
-            p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
-
-        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
-        DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
-        gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
-
-        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            std::string gemm_name = gemm_ptr->GetTypeString();
-
-            float ave_time =
-                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
-
-            std::size_t flop = 0, num_btype = 0;
-            for(std::size_t i = 0; i < gemm_descs.size(); i++)
-            {
-                flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
-                num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
-                             sizeof(CDataType) * Ms[i] * Ns[i];
-            }
-
-            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-
-            float gb_per_sec = num_btype / 1.E6 / ave_time;
-            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
-                      << gb_per_sec << " GB/s, " << gemm_name << std::endl;
-
-            if(tflops > best_tflops)
-            {
-                best_gemm_name  = gemm_name;
-                best_tflops     = tflops;
-                best_ave_time   = ave_time;
-                best_gb_per_sec = gb_per_sec;
-            }
-
-            if(do_verification)
-            {
-                for(std::size_t i = 0; i < gemm_descs.size(); i++)
-                {
-
-                    c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
-                    Tensor<CDataType> c_m_n_host_result(
-                        f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
-
-                    using ReferenceGemmInstance =
-                        ck::tensor_operation::host::ReferenceGemm<ADataType,
-                                                                  BDataType,
-                                                                  CDataType,
-                                                                  float,
-                                                                  AElementOp,
-                                                                  BElementOp,
-                                                                  CElementOp>;
-
-                    auto ref_gemm     = ReferenceGemmInstance{};
-                    auto ref_invoker  = ref_gemm.MakeInvoker();
-                    auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
-                                                              b_k_n[i],
-                                                              c_m_n_host_result,
-                                                              a_element_op,
-                                                              b_element_op,
-                                                              c_element_op);
-
-                    ref_invoker.Run(ref_argument);
-
-                    bool group_pass =
-                        ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
-                    pass = pass && group_pass;
-
-                    std::cout << "group: " << i << " verification result: " << std::boolalpha
-                              << group_pass << std::endl;
-
-                    if(do_log)
-                    {
-                        LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
-                            << std::endl;
-                        LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
-                        LogRangeAsType<float>(
-                            std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
-                            << std::endl;
-                        LogRangeAsType<float>(
-                            std::cout << "c_host : ", c_m_n_host_result.mData, ",")
-                            << std::endl;
-                    }
-                }
-            }
-        }
-        else
-        {
-            std::cout << "does not support this GEMM problem" << std::endl;
-        }
-    }
-
-    if(do_verification)
-    {
-        std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
-    }
-
-    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
-              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
-
-    return pass;
+    return profile_grouped_gemm_impl<ADataType,
+                                     BDataType,
+                                     CDataType,
+                                     float,
+                                     ALayout,
+                                     BLayout,
+                                     CLayout,
+                                     AElementOp,
+                                     BElementOp,
+                                     CElementOp>(do_verification,
+                                                 init_method,
+                                                 do_log,
+                                                 time_kernel,
+                                                 Ms,
+                                                 Ns,
+                                                 Ks,
+                                                 StrideAs,
+                                                 StrideBs,
+                                                 StrideCs,
+                                                 {1});
 }

 } // namespace profiler
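The net effect of the file above is that the fastgelu profiler becomes a thin wrapper: everything it used to duplicate now lives in the generic grouped-GEMM profiler, selected by the element-op template parameters. A condensed sketch of this delegation pattern, with illustrative names rather than the CK signatures:

    #include <vector>

    struct PassThrough { };
    struct FastGelu { }; // stand-in for ck::tensor_operation::element_wise::FastGelu

    // Generic profiler templated on the output element-wise operation.
    template <typename CElementOp = PassThrough>
    bool profile_generic(const std::vector<int>& Ms, const std::vector<int>& Ns)
    {
        // instance lookup, launch, verification ... (elided)
        return true;
    }

    // The fastgelu entry point pins the epilogue op and forwards, so tensor
    // setup, kbatch handling, and verification exist in exactly one place.
    bool profile_fastgelu(const std::vector<int>& Ms, const std::vector<int>& Ns)
    {
        return profile_generic<FastGelu>(Ms, Ns);
    }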
"SUCCESS" : "FAILURE") << std::endl; - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; - - return pass; + return profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + {1}); } } // namespace profiler diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 0ee0ee4c2e..a7b8e37563 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/convolution_parameter.hpp" @@ -25,13 +26,18 @@ namespace ck { namespace profiler { +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + template + typename CLayout, + typename AElementOp = PassThrough, + typename BElementOp = PassThrough, + typename CElementOp = PassThrough> bool profile_grouped_gemm_impl(int do_verification, int init_method, bool do_log, @@ -43,8 +49,8 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideBs, const std::vector& StrideCs, const std::vector& kbatches = {}, - int n_warmup = 1, - int n_iter = 10, + int n_warmup = -1, + int n_iter = -1, int instance_index = -1, bool fail_if_no_supported_instance = false) { @@ -93,7 +99,7 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + if(do_log) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i @@ -103,21 +109,17 @@ bool profile_grouped_gemm_impl(int do_verification, { case 0: break; case 1: - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k[i]); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n[i]); - max_abs_in_val = 2.f; + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n[i]); + max_abs_in_val = 5.f; break; default: - ck::utils::FillUniformDistribution{-0.5f, 0.5f}(a_m_k[i]); + ck::utils::FillUniformDistribution{0.0f, 1.0f}(a_m_k[i]); ck::utils::FillUniformDistribution{-0.5f, 0.5f}(b_k_n[i]); - max_abs_in_val = 0.5f; + max_abs_in_val = 1.0f; } } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - const auto a_element_op = AElementOp{}; const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; @@ -200,6 +202,17 @@ bool profile_grouped_gemm_impl(int do_verification, int num_kernel = 0; auto p_ds = std::vector>{}; + StreamConfig stream_config{nullptr, time_kernel}; + if(n_warmup >= 0) + { + stream_config.cold_niters_ = n_warmup; + } + + if(n_iter >= 0) + { + stream_config.nrepeat_ = n_iter; + } + if(do_verification) { for(std::size_t i = 0; i < gemm_descs.size(); i++) @@ -225,19 +238,33 @@ bool profile_grouped_gemm_impl(int do_verification, 
             ref_invoker.Run(ref_argument);
         }
     }
+
+    // Use a predefined sweep of kbatch values by default; a non-empty user-provided
+    // kbatches list overrides it.
+    std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
+    if(!kbatches.empty())
+    {
+        kbatch_list = kbatches;
+    }
+
+    // Check if the operation requested any KBatch size > 1
+    bool operation_requires_splitk_support = false;
+    for(auto kbatch : kbatch_list)
+    {
+        if(kbatch > 1)
+        {
+            operation_requires_splitk_support = true;
+            break;
+        }
+    }
+
     // profile device GEMM instances
-    int instances_supporting_all_batch_sizes = 0;
+    int instances_supported         = 0;
+    int instances_supporting_splitk = 0;
     for(auto& gemm_ptr : op_ptrs)
     {
-        auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(p_a,
-                                          p_b,
-                                          p_ds,
-                                          p_c,
-                                          gemm_descs,
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{});
+        auto argument_ptr = gemm_ptr->MakeArgumentPointer(
+            p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);

         auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
@@ -261,16 +288,9 @@ bool profile_grouped_gemm_impl(int do_verification,

         std::string gemm_name = gemm_ptr->GetTypeString();

-        std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
-
-        // If the user will provide not empty kbatches list, then we test predefined set of kbatch
-        // values.
-        if(!kbatches.empty())
-        {
-            kbatch_list = kbatches;
-        }
-
-        bool all_batch_sizes_supported = true;
+        // Track whether this instance supported anything at all, and whether it
+        // supported a genuine split-K (kbatch > 1) launch
+        bool any_supported_instance          = false;
+        bool any_supported_nontrivial_kbatch = false;
         for(std::size_t j = 0; j < kbatch_list.size(); j++)
         {
             auto kbatch_curr = kbatch_list[j];
@@ -290,11 +310,17 @@ bool profile_grouped_gemm_impl(int do_verification,
                 continue;
             }

+            // Record that this instance supports the current kbatch
+            any_supported_instance = true;
+            if(kbatch_curr > 1)
+            {
+                any_supported_nontrivial_kbatch = true;
+            }
+
             for(std::size_t i = 0; i < gemm_descs.size(); i++)
                 c_device_buf[i]->SetZero();

-            invoker_ptr->Run(argument_ptr.get(),
-                             StreamConfig{nullptr, false, 0, n_warmup, n_iter});
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);

             if(do_verification)
             {
@@ -329,7 +355,7 @@ bool profile_grouped_gemm_impl(int do_verification,
                     }
                 }

-                std::cout << "Instance: " << gemm_name << " verification "
+                std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
                           << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
"SUCCEED" : "FAILED") << std::endl; pass = pass && instance_pass; @@ -337,10 +363,6 @@ bool profile_grouped_gemm_impl(int do_verification, if(time_kernel) { - float ave_time = - invoker_ptr->Run(argument_ptr.get(), - StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - std::size_t flop = 0, num_btype = 0; for(std::size_t i = 0; i < gemm_descs.size(); i++) { @@ -370,24 +392,38 @@ bool profile_grouped_gemm_impl(int do_verification, } else { - all_batch_sizes_supported = false; - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + std::cout << "Instance: " << gemm_name + << ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")" << std::endl; } } - // If all batch sizes were supported by this instance, the instance can be marked as + // If any kbatch sizes > 1 were supported by this instance, the instance can be marked as // 'supported' for this problem - if(all_batch_sizes_supported) + if(any_supported_nontrivial_kbatch) { - ++instances_supporting_all_batch_sizes; + ++instances_supporting_splitk; + } + + if(any_supported_instance) + { + ++instances_supported; } } // Warn if not a single instance was supported - if(instances_supporting_all_batch_sizes == 0) + if(instances_supported == 0) { - std::cout << "Warning! No instance found that supported all of the batch sizes." + std::cout << "Warning! No supported instance found." << std::endl; + + if(fail_if_no_supported_instance) + { + return false; + } + } + else if(operation_requires_splitk_support && instances_supporting_splitk == 0) + { + std::cout << "Warning! No instance found that supported any of the kbatch sizes." << std::endl; if(fail_if_no_supported_instance) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 81e893edf5..7521aebc74 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -258,6 +258,7 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) +add_subdirectory(gemm_bias_add_reduce) add_subdirectory(gemm_blockscale_wp) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) diff --git a/test/gemm_bias_add_reduce/CMakeLists.txt b/test/gemm_bias_add_reduce/CMakeLists.txt new file mode 100644 index 0000000000..3fa1cc3904 --- /dev/null +++ b/test/gemm_bias_add_reduce/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance) + endif() +endif() diff --git a/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp new file mode 100644 index 0000000000..c0206e9218 --- /dev/null +++ b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
diff --git a/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
new file mode 100644
index 0000000000..c0206e9218
--- /dev/null
+++ b/test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
@@ -0,0 +1,106 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <tuple>
+
+#include "gtest/gtest.h"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "test_gemm_common.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+namespace {
+
+template <typename, typename>
+struct tuple_concat;
+
+template <typename... Xs, typename... Ys>
+struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
+{
+    using type = std::tuple<Xs..., Ys...>;
+};
+
+} // namespace
+
+template <typename Tuple>
+class TestGemmBiasAddReduce_FP16_MK_NK
+    : public ck::test::TestGemmBiasAddReduceCommon<
+          typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
+{
+};
+
+template <typename Tuple>
+class TestGemmBiasAddReduce_FP16_MK_KN
+    : public ck::test::TestGemmBiasAddReduceCommon<
+          typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
+{
+};
+
+template <typename Tuple>
+class TestGemmBiasAddReduce_FP16_KM_KN
+    : public ck::test::TestGemmBiasAddReduceCommon<
+          typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
+{
+};
+
+template <typename Tuple>
+class TestGemmBiasAddReduce_FP16_KM_NK
+    : public ck::test::TestGemmBiasAddReduceCommon<
+          typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
+{
+};
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple< F16, F16, F16, F16, F16, F32>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes);
+TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_KN, KernelTypes);
+TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_KN, KernelTypes);
+TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_NK, KernelTypes);
+
+TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular)
+{
+    std::vector<int> Ms{512};
+    constexpr int N = 512;
+    constexpr int K = 1024;
+
+    for(int M : Ms)
+        this->Run(M, N, K);
+}
+
+TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_KN, Regular)
+{
+    std::vector<int> Ms{512};
+    constexpr int N = 1024;
+    constexpr int K = 1024;
+
+    for(int M : Ms)
+        this->Run(M, N, K);
+}
+
+TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_KN, Regular)
+{
+    std::vector<int> Ms{256};
+    constexpr int N = 512;
+    constexpr int K = 1024;
+
+    for(int M : Ms)
+        this->Run(M, N, K);
+}
+
+TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_NK, Regular)
+{
+    std::vector<int> Ms{256};
+    constexpr int N = 1024;
+    constexpr int K = 1024;
+
+    for(int M : Ms)
+        this->Run(M, N, K);
+}
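The tuple_concat helper above flattens the layout pair and the data-type tuple into the single eight-element tuple the common fixture expects. A compile-time check of that shape (illustrative; assumes <type_traits> is in scope and reuses the aliases defined in the test file):

    static_assert(
        std::is_same_v<
            tuple_concat<std::tuple<Row, Col>,
                         std::tuple<F16, F16, F16, F16, F16, F32>>::type,
            std::tuple<Row, Col, F16, F16, F16, F16, F16, F32>>,
        "MK_NK prepends RowMajor/ColumnMajor to the six data types");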
diff --git a/test/gemm_bias_add_reduce/test_gemm_common.hpp b/test/gemm_bias_add_reduce/test_gemm_common.hpp
new file mode 100644
index 0000000000..7c62f56843
--- /dev/null
+++ b/test/gemm_bias_add_reduce/test_gemm_common.hpp
@@ -0,0 +1,61 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "gtest/gtest.h"
+#include "ck/ck.hpp"
+#include "profiler/profile_gemm_bias_add_reduce_impl.hpp"
+
+namespace ck {
+namespace test {
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using F32 = float;
+
+template <typename Tuple>
+class TestGemmBiasAddReduceCommon : public ::testing::Test
+{
+    protected:
+    using ALayout        = std::tuple_element_t<0, Tuple>;
+    using BLayout        = std::tuple_element_t<1, Tuple>;
+    using CLayout        = Row;
+    using ADataType      = std::tuple_element_t<2, Tuple>;
+    using BDataType      = std::tuple_element_t<3, Tuple>;
+    using CDataType      = std::tuple_element_t<4, Tuple>;
+    using BiasDataType   = std::tuple_element_t<5, Tuple>;
+    using D0DataType     = std::tuple_element_t<6, Tuple>;
+    using ReduceDataType = std::tuple_element_t<7, Tuple>;
+
+    public:
+    static constexpr bool verify_     = true;
+    static constexpr int init_method_ = 1; // integer value initialization
+    static constexpr bool log_        = false;
+    static constexpr bool bench_      = false; // measure kernel performance
+
+    void Run(const int M, const int N, const int K)
+    {
+        bool all_success = true;
+
+        int StrideA  = std::is_same_v<ALayout, Row> ? K : M;
+        int StrideB  = std::is_same_v<BLayout, Row> ? N : K;
+        int StrideD0 = std::is_same_v<CLayout, Row> ? N : M;
+        int StrideC  = std::is_same_v<CLayout, Row> ? N : M;
+
+        all_success =
+            all_success &
+            ck::profiler::profile_gemm_bias_add_reduce_impl<ADataType,
+                                                            BDataType,
+                                                            CDataType,
+                                                            BiasDataType,
+                                                            D0DataType,
+                                                            ReduceDataType,
+                                                            ALayout,
+                                                            BLayout,
+                                                            CLayout>(
+                verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, StrideD0);
+
+        EXPECT_TRUE(all_success);
+    }
+};
+
+} // namespace test
+} // namespace ck
diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt
index c6b5180013..450950cbd6 100644
--- a/test/grouped_gemm/CMakeLists.txt
+++ b/test/grouped_gemm/CMakeLists.txt
@@ -12,6 +12,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
         target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
         add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
     endif()
+
+    add_gtest_executable(test_grouped_gemm_fastgelu test_grouped_gemm_fastgelu.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance)
+        add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu)
+    endif()
 endif()

 add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
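The stride selection in Run() above follows the usual dense-GEMM convention: a row-major matrix's leading stride is its column count, a column-major one's is its row count. The same rule as a tiny helper (illustrative sketch, assuming the Row alias from the header):

    #include <type_traits>

    template <typename Layout>
    int leading_stride(int rows, int cols)
    {
        // row-major: consecutive rows are `cols` elements apart;
        // column-major: consecutive columns are `rows` elements apart.
        return std::is_same_v<Layout, Row> ? cols : rows;
    }
    // e.g. leading_stride<ALayout>(M, K) reproduces StrideA above.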
diff --git a/test/grouped_gemm/test_grouped_gemm_fastgelu.cpp b/test/grouped_gemm/test_grouped_gemm_fastgelu.cpp
new file mode 100644
index 0000000000..b792dd707d
--- /dev/null
+++ b/test/grouped_gemm/test_grouped_gemm_fastgelu.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <cstdlib>
+#include <iostream>
+
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+#include "ck/utility/data_type.hpp"
+
+#include "gtest/gtest.h"
+#include "test_grouped_gemm_util.hpp"
+
+ck::index_t param_mask     = 0xffffff;
+ck::index_t instance_index = -1;
+
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F8   = ck::f8_t;
+using I8   = int8_t;
+
+using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
+using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
+using CDEElementOp = ck::tensor_operation::element_wise::FastGelu;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <typename Tuple>
+class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
+{
+};
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Col, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Col, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
+
+#include "test_grouped_gemm_ut_cases.inc"
+
+int main(int argc, char** argv)
+{
+    testing::InitGoogleTest(&argc, argv);
+    if(argc == 3)
+    {
+        param_mask     = strtol(argv[1], nullptr, 0);
+        instance_index = atoi(argv[2]);
+    }
+    else if(argc != 1)
+    {
+        std::cout << "Usage: " << argv[0] << " [param_mask instance_index]" << std::endl;
+        std::cout << "  param_mask:     bitmask selecting parameter cases" << std::endl;
+        std::cout << "  instance_index: index of the instance to run (-1 means all)" << std::endl;
+    }
+    return RUN_ALL_TESTS();
+}
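For reference, the optional arguments parsed in main() above give the following invocation contract (the binary name comes from the CMake target; mask and index values here are purely illustrative):

    ./test_grouped_gemm_fastgelu                # run all parameter cases and all instances
    ./test_grouped_gemm_fastgelu 0xffffff 2     # all parameter cases, instance 2 only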
diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
index 84558c89f9..f0b4ee6108 100644
--- a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
@@ -65,12 +65,11 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)

 TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
 {
-    // gfx11 does not support split-K due to missing atomic add for fp16/bf16
-    // Technically, we could still run the tests for fp32, but we currently don't have instances for
-    // it so we disable it entirely
-    if(ck::is_gfx11_supported())
-        GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
-                        "instructions";
+    // Split-K is not supported in some configurations; running this test there would fail
+    // because no instance would be supported, so we skip it
+    if(!this->IsSplitKSupported())
+        GTEST_SKIP() << "Split-K not supported for the current configuration (FP16/BF16 on "
+                        "GFX11, or using a CDE element-wise operation)";

     const std::vector<int> Ms{188, 210};
     constexpr int N = 768;
diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp
index 6ee6465cc4..1fed403f2f 100644
--- a/test/grouped_gemm/test_grouped_gemm_util.hpp
+++ b/test/grouped_gemm/test_grouped_gemm_util.hpp
@@ -7,11 +7,14 @@
 #include <algorithm>
 #include <iostream>
 #include <numeric>
+#include <tuple>
 #include <vector>
 #include <gtest/gtest.h>

 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "profiler/profile_grouped_gemm_impl.hpp"

 extern ck::index_t param_mask;
@@ -32,16 +35,46 @@ std::string serialize_range(const Range& range)
     return std::string(str.begin(), str.end() - 2);
 }

+// Helper primary template (will be specialized on the boolean)
+template <std::size_t I,
+          typename Tuple,
+          typename Default,
+          bool InRange = (I < std::tuple_size_v<std::remove_reference_t<Tuple>>)>
+struct tuple_element_or_impl;
+
+// Specialization for the in-range case: use std::tuple_element_t
+template <std::size_t I, typename Tuple, typename Default>
+struct tuple_element_or_impl<I, Tuple, Default, true>
+{
+    using type = std::tuple_element_t<I, std::remove_reference_t<Tuple>>;
+};
+
+// Specialization for the out-of-range case: use Default
+template <std::size_t I, typename Tuple, typename Default>
+struct tuple_element_or_impl<I, Tuple, Default, false>
+{
+    using type = Default;
+};
+
+// User-facing alias
+template <std::size_t I, typename Tuple, typename Default>
+using tuple_element_or_t = typename tuple_element_or_impl<I, Tuple, Default>::type;
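The defaulting behavior of tuple_element_or_t can be pinned down with two compile-time checks (illustrative only, assuming <tuple> and <type_traits> are in scope):

    static_assert(std::is_same_v<tuple_element_or_t<0, std::tuple<int, float>, double>, int>,
                  "index in range: the tuple element is used");
    static_assert(std::is_same_v<tuple_element_or_t<5, std::tuple<int, float>, double>, double>,
                  "index out of range: the default is used");

This is what lets the fixture accept both the old 6-element test tuples and the new 9-element ones that carry element-wise operations.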
+
 template <typename Tuple, bool FailIfNoSupportedInstances = false>
 class TestGroupedGemm : public testing::Test
 {
     protected:
-    using ALayout   = std::tuple_element_t<0, Tuple>;
-    using BLayout   = std::tuple_element_t<1, Tuple>;
-    using ELayout   = std::tuple_element_t<2, Tuple>;
-    using ADataType = std::tuple_element_t<3, Tuple>;
-    using BDataType = std::tuple_element_t<4, Tuple>;
-    using EDataType = std::tuple_element_t<5, Tuple>;
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+    using ALayout      = std::tuple_element_t<0, Tuple>;
+    using BLayout      = std::tuple_element_t<1, Tuple>;
+    using ELayout      = std::tuple_element_t<2, Tuple>;
+    using ADataType    = std::tuple_element_t<3, Tuple>;
+    using BDataType    = std::tuple_element_t<4, Tuple>;
+    using EDataType    = std::tuple_element_t<5, Tuple>;
+    using AElementOp   = tuple_element_or_t<6, Tuple, PassThrough>;
+    using BElementOp   = tuple_element_or_t<7, Tuple, PassThrough>;
+    using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;

     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -57,15 +90,25 @@ class TestGroupedGemm : public testing::Test
     bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
     std::vector<int> k_batches_;

-    void SetUp() override
+    bool IsSplitKSupported()
     {
+        // gfx11 does not support split-K due to missing atomic add for fp16/bf16
+        // Technically, we could still use split-K for fp32, but we currently don't have
+        // instances for it so we disable it entirely
         constexpr bool require_16bit_atomic_add =
             std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
-        if(require_16bit_atomic_add && ck::is_gfx11_supported())
+        bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
+
+        // CDE element operators are not supported in combination with split-K
+        constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;
+
+        return !missing_atomic_add && !has_cde_element_operator;
+    }
+
+    void SetUp() override
+    {
+        if(!IsSplitKSupported())
         {
-            // gfx11 does not support split-K due to missing atomic add for fp16/bf16
-            // Technically, we could still use split-K for fp32, but we currently don't have
-            // instances for it so we disable it entirely
             k_batches_ = {1};
         }
         else
@@ -147,21 +190,24 @@ class TestGroupedGemm : public testing::Test
                                                float,
                                                ALayout,
                                                BLayout,
-                                               ELayout>(verify_,
-                                                        init_method_,
-                                                        log_,
-                                                        bench_,
-                                                        Ms,
-                                                        Ns,
-                                                        Ks,
-                                                        StrideAs,
-                                                        StrideBs,
-                                                        StrideCs,
-                                                        kbatches,
-                                                        n_warmup_,
-                                                        n_iter_,
-                                                        instance_index,
-                                                        fail_if_no_supported_instances_);
+                                               ELayout,
+                                               AElementOp,
+                                               BElementOp,
+                                               CDEElementOp>(verify_,
+                                                             init_method_,
+                                                             log_,
+                                                             bench_,
+                                                             Ms,
+                                                             Ns,
+                                                             Ks,
+                                                             StrideAs,
+                                                             StrideBs,
+                                                             StrideCs,
+                                                             kbatches,
+                                                             n_warmup_,
+                                                             n_iter_,
+                                                             instance_index,
+                                                             fail_if_no_supported_instances_);
         EXPECT_TRUE(pass);
     }
 };
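Background on why the split-K gating above exists at all (conceptual sketch, not CK code): each of the KBatch work-groups computes a partial GEMM over its K-slice and accumulates into the same output tile, so the final store must be an atomic add; gfx11 lacks the packed fp16/bf16 atomic-add instructions, hence the restriction. A host-side rendering of the decomposition, assuming K is divisible by KBatch for simplicity:

    // Conceptual split-K decomposition for C += A * B with KBatch slices.
    // On the GPU the kb loop runs across work-groups in parallel, which is
    // what forces the '+=' on C to be an atomic add in device code.
    void splitk_gemm(const float* A, const float* B, float* C,
                     int M, int N, int K, int KBatch)
    {
        for(int kb = 0; kb < KBatch; ++kb)        // parallel across work-groups
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = kb * (K / KBatch); k < (kb + 1) * (K / KBatch); ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    C[m * N + n] += partial;      // atomic add on the device
                }
    }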