From a61e73bc56966a138ab1b5dadf27983800788431 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:08:08 -0500 Subject: [PATCH 1/3] Add instances for conv_scale with fp8@bf8->fp8 (#1220) * Update device op api to support BComputeType * Add example * Add instances * Add profiler mode * Add client example * Update copyright year * Add BComputeType check * Fix compute types --- client_example/16_convnd_fwd/CMakeLists.txt | 5 + client_example/16_convnd_fwd/common.hpp | 8 +- .../16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp | 50 ++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + .../09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp | 83 ++++++++++++ .../device_grouped_conv_fwd_multiple_abd.hpp | 10 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 15 ++- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 10 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 39 +++--- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 7 +- .../device_grouped_conv_fwd_xdl_instance.hpp | 36 ++++++ .../gpu/grouped_convolution_forward.hpp | 118 ++++++++++++------ .../gpu/grouped_convolution_forward_xdl.inc | 18 +++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ..._ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp | 54 ++++++++ .../profile_grouped_conv_fwd_impl.hpp | 10 +- profiler/src/profile_grouped_conv_fwd.cpp | 75 ++++++----- 17 files changed, 441 insertions(+), 103 deletions(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e034c468d5..808693b632 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) +endif() + if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a5b7c5b42e..ee408c7443 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
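// Note on the CMake gating above: the fp8@bf8 client target is added only when
// both "fp8" and "bf8" appear in DTYPES, or when DTYPES is left undefined. A
// hedged configure example (flag spelling per CK's usual DTYPES convention;
// verify against your branch):
//
//   cmake -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D DTYPES="fp8;bf8" ..
//   make client_conv3d_fwd_fp8_bf8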
#include #include @@ -95,7 +95,8 @@ template + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array; + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp new file mode 100644 index 0000000000..8508dc9c55 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 61e9a43c3a..afbe741212 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -5,6 +5,7 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000..53a12377c5 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
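// The client example above fixes Di = Hi = 28 with a 3x3x3 filter and expects
// Do = Ho = 28 (and Wi = Wo = 3). Those shapes are consistent with the usual
// forward-conv size relation under stride 1, dilation 1, and padding 1 per
// side, which run_grouped_conv_fwd's defaults presumably supply. A minimal
// self-contained check of that arithmetic (hypothetical helper, not from the
// patch):
//
//   constexpr int conv_out_size(int in, int filter, int pad, int stride = 1,
//                               int dilation = 1)
//   {
//       return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
//   }
//   static_assert(conv_out_size(28, 3, 1) == 28, "Di -> Do");
//   static_assert(conv_out_size(3, 3, 1) == 3, "Wi -> Wo");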
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index fa3dcfdf20..31e8d639ad 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval().IsTuple()); * \tparam AElementwiseOperation A elementwise operation. * \tparam BElementwiseOperation B elementwise operation. * \tparam CDEElementwiseOperation CDE elementwise operation. - * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed). + * \tparam BComputeType Compute data type for B tensor (default: AComputeType). 
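 *
 * Illustration (a sketch, with layouts and elementwise operations abbreviated;
 * not a declaration from this header): an fp8-activation / bf8-weight forward
 * convolution selects distinct compute types through the two trailing
 * parameters, e.g.
 *
 *   DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
 *                                   f8_t, bf8_t, ck::Tuple<>, f8_t,
 *                                   PassThrough, PassThrough, PassThrough,
 *                                   f8_t, bf8_t>; // AComputeType, BComputeType
 *
 * Uniform-precision users are unaffected, since BComputeType defaults to
 * AComputeType.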
*/ template ::value, Number<0>, - ADataType>())> // ComputeType is InputType by default (first + ADataType>()), // AComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed + typename BComputeType = AComputeType> struct DeviceGroupedConvFwdMultipleABD : public BaseOperator { static constexpr bool isMultiA = is_detected::value; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 5ff42f98f3..f53ec8a4e8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -254,13 +254,14 @@ template ::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD + AComputeDataType, + BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; @@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmBDataType = std::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ - GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ @@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, @@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - ComputeDataType, + AComputeDataType, + BComputeDataType, LoopSched>; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 4b7cc56796..0f98f9e63d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,7 +30,7 @@ namespace ck { // D0, D1, ... and E have the same layout template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleABD_xdl_cshuffle { static constexpr index_t NumATensor = AsDataType::Size(); @@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle decltype(GridwiseGemmPipeline_Selector())>; #if CK_WORKAROUND_DENORM_FIX - using ComputeDataType = - conditional_t, ck::bhalf_t, ComputeDataType_>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using ComputeDataType = ComputeDataType_; + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, AsDataType, - Tuple, + Tuple, decltype(as_grid_desc_ak0_m_ak1), decltype(tie(a_block_desc_ak0_m_ak1)), AElementwiseOperation, @@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, BsDataType, - Tuple, + Tuple, decltype(bs_grid_desc_bk0_n_bk1), decltype(tie(b_block_desc_bk0_n_bk1)), BElementwiseOperation, @@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeDataType, // ComputeDataType for A - ComputeDataType, // ComputeDataType for B + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, 
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index c0a3d29f85..6ddc3aca18 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -73,7 +73,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle #if CK_WORKAROUND_DENORM_FIX using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 0f845ca1ed..40878e4f0e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 24a5f9a5cb..e61ec28284 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -34,7 +34,8 @@ template + typename AComputeType, + typename BComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = 
DeviceGroupedConvFwdMultipleABD; + AComputeType, + BComputeType>; static auto GetInstances() { @@ -75,14 +78,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -94,14 +99,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -115,14 +122,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } @@ -130,14 +139,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); } @@ -149,14 +161,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -164,7 +178,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } @@ -176,14 +192,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -191,7 +209,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } @@ -203,14 +223,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { 
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); } @@ -218,14 +240,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); } @@ -237,7 +262,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } @@ -245,28 +271,40 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( op_ptrs); } if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs); } #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } @@ -274,14 +312,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } @@ -295,7 +336,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); @@ -305,7 +347,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); @@ -320,7 +363,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); @@ -335,7 +379,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( @@ -347,7 +392,8 @@ struct 
DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); @@ -363,7 +409,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( @@ -375,7 +422,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 942674ef99..691414ebcb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -351,6 +351,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( BF8>>>& instances); #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 972fb54031..50a6ec9a45 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -41,4 +41,9 @@ if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp new file mode 100644 index 0000000000..d42104bf6b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
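// For orientation, a sketch of the client-side query these registrations feed
// (mirroring common.hpp in the client example above; PassThrough elementwise
// ops assumed):
//
//   using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
//       3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK,
//       ck::f8_t, ck::bf8_t, ck::Tuple<>, ck::f8_t,
//       PassThrough, PassThrough, PassThrough,
//       ck::f8_t, ck::bf8_t>; // AComputeType = fp8, BComputeType = bf8
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//
// Every instance added below (Default, 1x1P0, and 1x1S1P0 specializations)
// becomes visible through op_ptrs once this translation unit is linked in.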
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index f629809daa..d913873305 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,7 +31,9 @@ template + typename OutDataType, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool profile_grouped_conv_fwd_impl(int do_verification, int init_method, bool do_log, @@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>; + OutElementOp, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 1f72733729..a847999b56 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -25,6 +25,7 @@ enum struct ConvDataType INT8_INT8_INT8, // 3 F8_F8_F8, // 4 BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 }; #define OP_NAME "grouped_conv_fwd" @@ -40,7 +41,8 @@ static void print_helper_msg() << " 2: Input bf16, Weight bf16, Output bf16\n" << " 3: Input int8, Weight int8, Output int8\n" << " 4: Input fp8, Weight fp8, Output fp8\n" - << " 5: Input bf8, Weight bf8, Output fp8)\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) auto out_layout, auto in_type, auto wei_type, - auto out_type) { + auto out_type, + auto a_compute_type, + auto b_compute_type) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using InLayout = decltype(in_layout); @@ -129,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using WeiDataType = decltype(wei_type); using OutDataType = decltype(out_type); + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + bool pass = 
ck::profiler::profile_grouped_conv_fwd_impl( + OutDataType, + AComputeType, + BComputeType>( do_verification, init_method, do_log, time_kernel, params); return pass ? 0 : 1; @@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } // NHWGC_GKYXC_NHWGK @@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}); + 
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
         }
     }
     else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
         }
     }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(
+                I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(
+                I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
         }
         else if(data_type == ConvDataType::F8_F8_F8)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{});
         }
         else if(data_type == ConvDataType::BF8_BF8_F8)
         {
-            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{});
+        }
+        else if(data_type == ConvDataType::F8_BF8_F8)
+        {
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
         }
     }

From c701071666ce5656c8bd4331979f56fcc497fda6 Mon Sep 17 00:00:00 2001
From: jakpiase
Date: Thu, 4 Apr 2024 11:01:33 +0200
Subject: [PATCH 2/3] Add Grouped Gemm Multiple D SplitK TwoStage (#1212)

* Support A/B/C elementwise ops.
* First part of GGEMM multiD splitk two stage.
* WIP - changes for debugging.
* tmp save
* working version
* added bf16@int8 version
* fixes
* add reviewers suggestions
* pre-committed missing files
* switched to ifs from elseifs

---------

Co-authored-by: Adam Osewski
---
 ...rouped_gemm_multiple_d_splitk_xdl_fp16.cpp | 394 +++++++
 .../device_grouped_gemm_multiple_d_splitk.hpp | 136 +++
 .../device/impl/device_elementwise_impl.hpp   |  16 +-
 ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 987 ++++++++++++++++++
 ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp |  27 +-
 .../cpu/reference_gemm_multiple_d.hpp         | 175 ++++
 .../gpu/grouped_gemm.hpp                      |  50 +-
 .../gpu/grouped_gemm/CMakeLists.txt           |   2 +
 ...o_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp |  99 ++
 ...wo_stage_f16_f16_f16_mk_kn_mn_instance.cpp |  96 ++
 .../profile_grouped_gemm_two_stage_impl.hpp   | 366 +++++++
 profiler/src/CMakeLists.txt                   |   1 +
 .../src/profile_grouped_gemm_two_stage.cpp    | 157 +++
 13 files changed, 2490 insertions(+), 16 deletions(-)
 create mode 100644 example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
 create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
 create mode 100644 profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
 create mode 100644 profiler/src/profile_grouped_gemm_two_stage.cpp

diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
new file mode 100644
index 0000000000..ecff7b4713
--- /dev/null
+++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
@@ -0,0 +1,394 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
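// The example below exercises the new two-stage split-K path. Stage one runs
// the grouped GEMMs with the K dimension sliced into k_batch partial sums
// (written to the workspace obtained via GetWorkSpaceSize); a second pass then
// combines the partials and applies the AddAdd epilogue with the two D
// tensors, which is what the DeviceElementwiseImpl changes in this patch
// appear to support. With the defaults used here the per-slice work is easy
// to see (assumption: k_batch evenly divides K, as 4608 / 128 == 36):
//
//   const int K = 4608, k_batch = 128;
//   const int k_per_split = K / k_batch; // 36 K-elements accumulated per slice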
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDMatrices = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + 
std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 128; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDMatrices>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDMatrices> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + problem_size.stride_Ds[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + gemm.SetKBatchSize(argument, config.k_batch); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
+    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
+
+    DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
+    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
+
+    invoker.Run(argument, StreamConfig{nullptr, false, 1});
+
+    if(config.time_kernel)
+    {
+        float ave_time   = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << gemm.GetTypeString() << std::endl;
+    }
+
+    bool pass = true;
+    if(config.do_verification)
+    {
+        using ReferenceGemmInstance =
+            ck::tensor_operation::host::ReferenceGemmMultipleD;
+
+        for(std::size_t i = 0; i < gemm_descs.size(); i++)
+        {
+            auto karg = argument.gemm_kernel_args_[i].karg_;
+            auto dev_res_tensor =
+                Tensor<EDataType>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{}));
+
+            c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(),
+                                            c_device_result_tensors[i].mDesc.GetElementSize() *
+                                                sizeof(EDataType));
+            auto ref_gemm    = ReferenceGemmInstance{};
+            auto ref_invoker = ref_gemm.MakeInvoker();
+
+            auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
+                                                      b_tensors[i],
+                                                      d_tensors[i],
+                                                      c_host_tensors[i],
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      cde_element_op);
+
+            ref_invoker.Run(ref_argument);
+            pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
+        }
+
+        std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
+    }
+
+    return pass;
+}
+
+std::vector<int> argToIntArray(char* input)
+{
+    std::vector<int> out;
+
+    std::istringstream in(input);
+
+    std::string item;
+
+    while(std::getline(in, item, ','))
+    {
+        out.push_back(std::stoi(item));
+    }
+
+    return out;
+}
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    if(argc < 11)
+    {
+        std::vector<int> Ms{64, 127, 255, 129, 260, 190, 77};
+        problem_size.group_count = Ms.size();
+
+        for(int i = 0; i < problem_size.group_count; i++)
+        {
+            problem_size.Ms.push_back(Ms[i]);
+            problem_size.Ns.push_back(252);
+            problem_size.Ks.push_back(4608);
+
+            problem_size.stride_As.push_back(problem_size.Ks[i]);
+            problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+            problem_size.stride_Cs.push_back(problem_size.Ns[i]);
+
+            problem_size.stride_Ds.push_back({});
+            for(int j = 0; j < NumDMatrices; ++j)
+            {
+                problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
+            }
+        }
+
+        std::cout
+            << "Usage:\n"
+            << "arg1: verification (0=no, 1=yes)\n"
+            << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+            << "arg3: time kernel (0=no, 1=yes)\n"
+            << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
+               "64,64 64,64 128,128)\n"
+            << "arg10: k_batch (> 0)\n"
+            << "... setting default values."
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[10]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp new file mode 100644 index 0000000000..d91eac0730 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. +/// +template +struct GroupedGemmMultipleDKernelArguments +{ + __host__ __device__ + GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. 
+ /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp index 37867f1eaa..1a44c3ed9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,10 +22,12 @@ namespace device { template + index_t NumDim, // The max dim of input tensors + // the tensors descs have to be aligned, such that + // the innermost dim is the contiguous one. + index_t MPerThread, // How many elements per thread to read + typename InScalarPerVectorSeq, // Scalar per vec for each Input + typename OutScalarPerVectorSeq> // Scalar per vec for each Output struct DeviceElementwiseImpl : public DeviceElementwise { @@ -242,13 +244,13 @@ struct DeviceElementwiseImpl static_for<0, NumInput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); static_for<0, NumOutput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); return valid; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp new file mode 100644 index 0000000000..2d60c027bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -0,0 +1,987 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
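+
+// Editor's note: a minimal host-side sketch (not part of this header) of the call
+// sequence the DeviceGroupedGemmMultipleDSplitK interface above expects; it mirrors
+// the grouped GEMM example earlier in this patch. `op`, `k_batch`, and the pointer
+// vectors are placeholder names. The workspace must be attached before the kernel
+// arguments are copied to the device, because SetWorkSpacePointer() redirects each
+// group's intermediate (first-stage) output into the workspace:
+//
+//     auto op       = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage</*...*/>{};
+//     auto invoker  = op.MakeInvoker();
+//     auto argument = op.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
+//                                     PassThrough{}, PassThrough{}, PassThrough{});
+//     op.SetKBatchSize(argument, k_batch);
+//     DeviceMem workspace(op.GetWorkSpaceSize(&argument));
+//     op.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
+//     DeviceMem arg_buf(op.GetDeviceKernelArgSize(&argument));
+//     op.SetDeviceKernelArgs(argument, arg_buf.GetDeviceBuffer());
+//     invoker.Run(argument, StreamConfig{});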
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template = false> +struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + : public DeviceGroupedGemmMultipleDSplitK +{ + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1 + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using WorkspaceDataType = float; + + // First stage GridwiseGEMM kernel. + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + WorkspaceDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, // CElementwiseOperation + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeDataType>; + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + 
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&]([[maybe_unused]] auto i) constexpr { + return Number{}; + }, + Number{}); + } + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using DsGridPointer = decltype(MakeDsGridPointer()); + using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); + using CDDataTypes = decltype(concat_tuple(ck::Tuple{}, DsGridPointer{})); + + using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence()); + + static constexpr index_t ClusterLengthMPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + ElementwiseInputSequence, + ck::Sequence, + true>; + + // Block2CTileMap configuration parameter. 
+ static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using GemmKernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + GemmKernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op, + DefaultKBatch) + { + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t kbatch) + : K_BATCH{kbatch}, + group_count_{0}, + skipped_group_count_{0}, + grid_size_{0}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + p_Ds_{p_Ds} + { + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("Error! 
group_count_ != p_As/Bs/Ds/Es size"); + } + + gemm_kernel_args_.reserve(group_count_); + elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_d_grid_descs_m_n_.reserve(group_count_); + ds_grid_pointer_.reserve(group_count_); + group_grid_size_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M * N * K == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_e = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + + DsGridDesc_M_N ds_grid_desc_m_n; + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + group_grid_size_[i] = grid_size_grp; + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + std::array stride_ds; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + stride_ds[j] = gemm_descs[i].stride_Ds_[j]; + }); + stride_Ds_.emplace_back(std::move(stride_ds)); + + // We first set E pointer to actual operation output, but later on + // when workspace will be set, this will be updated to workspace memory. + auto karg = GemmKernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_e, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); + ds_grid_pointer_.push_back(p_ds_grid); + } + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_ = p_Es; + } + + /** + * @brief Set new kbatch value. + * + * @param[in] kbatch The new splitK parameter value. 
+     */
+    void UpdateKBatch(index_t kbatch)
+    {
+        K_BATCH    = kbatch;
+        grid_size_ = 0;
+
+        // The constructor only reserves group_grid_size_, so size it explicitly
+        // before the indexed assignment below.
+        group_grid_size_.resize(gemm_kernel_args_.size());
+
+        for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
+        {
+            auto& karg = gemm_kernel_args_[i].karg_;
+
+            const index_t k_padded  = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
+            const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
+
+            const auto c_grid_desc_m_n =
+                GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+
+            const auto local_b2c_tile_map =
+                Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
+            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
+
+            const index_t block_start = grid_size_;
+            const index_t block_end   = grid_size_ + grid_size_grp;
+
+            grid_size_ += grid_size_grp;
+
+            // block-to-e-tile map
+            auto grouped_block_2_ctile_map =
+                GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+
+            group_grid_size_[i] = grid_size_grp;
+            karg.KPadded        = k_padded;
+            karg.K0Padded       = k0_padded;
+            karg.k_batch        = K_BATCH;
+            gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
+            gemm_kernel_args_[i].block_start_       = block_start;
+            gemm_kernel_args_[i].block_end_         = block_end;
+
+#if DEBUG_LOG
+            index_t tiles = (block_end - block_start) / K_BATCH;
+            std::cout << "block_start: " << block_start << "\n"
+                      << "block_end: " << block_end << "\n"
+                      << "tiles: " << tiles << std::endl
+                      << std::endl;
+
+            std::cout << "KPadded: " << karg.KPadded << std::endl
+                      << "K0Padded: " << karg.K0Padded << std::endl
+                      << "KBatch: " << karg.k_batch << std::endl
+                      << "grid_size_: " << grid_size_grp << std::endl;
+#endif
+        }
+    }
+
+    void UpdateEPointers()
+    {
+        // Set up each group's E pointer to its designated workspace memory.
+        WorkspaceDataType* p_workspace = reinterpret_cast<WorkspaceDataType*>(p_workspace_);
+        std::size_t offset             = 0;
+
+        for(auto& arg : gemm_kernel_args_)
+        {
+            arg.karg_.p_c_grid = p_workspace + offset;
+            index_t tiles      = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
+            offset += tiles * MPerBlock * NPerBlock;
+#if DEBUG_LOG
+            std::cout << "block_start: " << arg.block_start_ << "\n"
+                      << "block_end: " << arg.block_end_ << "\n"
+                      << "tiles: " << tiles << "\n"
+                      << "offset: " << offset << std::endl;
+#endif
+        }
+    }
+
+    std::size_t GetWorkspaceSizeBytes() const
+    {
+        std::size_t size_bytes{0};
+
+        for(const auto& arg : gemm_kernel_args_)
+        {
+            index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
+            size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType);
+        }
+        return size_bytes;
+    }
+
+    std::size_t GetWorkspaceSize(std::size_t group) const
+    {
+        const auto& arg = gemm_kernel_args_[group];
+        index_t tiles   = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
+        return tiles * MPerBlock * NPerBlock;
+    }
+
+    // private:
+    index_t K_BATCH;
+    index_t group_count_;
+    index_t skipped_group_count_;
+    index_t grid_size_;
+    // Pointer to device memory with GEMM kernel arguments; nullptr until
+    // SetDeviceKernelArgs() is called, so the null-check in Run() is well-defined.
+    const void* p_dev_gemm_args_ = nullptr;
+
+    AElementwiseOperation a_element_op_;
+    BElementwiseOperation b_element_op_;
+    CDEElementwiseOperation cde_element_op_;
+
+    std::vector<std::array<const void*, NumDTensor>>& p_Ds_;
+    std::vector<std::array<index_t, NumDTensor>> stride_Ds_;
+    std::vector<GemmTransKernelArg> gemm_kernel_args_;
+    std::vector<index_t> group_grid_size_;
+
+    std::vector<EGridDesc_M_N> elementwise_c_grid_descs_m_n_;
+    std::vector<DsGridDesc_M_N> elementwise_d_grid_descs_m_n_;
+    std::vector<DsGridPointer> ds_grid_pointer_;
+    std::vector<void*> e_ptrs_;
+    };
+
+    // Invoker
+    struct Invoker : public BaseInvoker
+    {
+        ///
+        /// @brief Launch Grouped Gemm kernel.
+ /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary + /// workspace. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config = StreamConfig{}) + { + auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = + CheckArgument(arg, stream_config); + + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(dev_gemm_workspace == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + + if(all_have_main_k_block_loop) + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + else + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see + /// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly + /// allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.p_workspace_ == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const + { + bool all_have_kbatch_gt_one, all_have_main_k_block_loop; + + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1( + arg.gemm_kernel_args_[0].karg_.M, + arg.gemm_kernel_args_[0].karg_.MPadded, + arg.gemm_kernel_args_[0].karg_.K, + arg.gemm_kernel_args_[0].karg_.StrideA, + arg.gemm_kernel_args_[0].karg_.k_batch, + arg.gemm_kernel_args_[0].karg_.K0Padded, + arg.gemm_kernel_args_[0].karg_.KPadded); + + all_have_kbatch_gt_one = arg.K_BATCH > 1; + all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + } + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + + bool not_all_have_main_k_block_loop_same = + all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + bool not_all_have_kbatch_value_same = + all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); + + if(not_all_have_main_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" + << "group [" << i << "], kbatch: " << gemm_arg.k_batch + << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); + } + + template + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + const auto gemm_kernel = + kernel_grouped_gemm_xdl_splitk; + + const auto elementwise_kernel = kernel_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation>; + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_args, + dev_gemm_workspace, + stream_config); + } + + template + float LaunchKernel(const KernelFunction& gemm_kernel, + const KernelFunction2& elementwise_kernel, + const Argument& arg, + const void* dev_gemm_args, + [[maybe_unused]] void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + float time{0.f}; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + // GEMM kernel + time = launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + gemm_kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + PassThrough{}); + + // Elementwise kernels + for(int i = 0; i < arg.group_count_; ++i) + { + time += launch_and_time_kernel( + stream_config, + elementwise_kernel, + dim3(arg.group_grid_size_[i]), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + arg.elementwise_d_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), + arg.ds_grid_pointer_[i]), + type_convert(arg.e_ptrs_[i]), + Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0), + arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)}, + arg.cde_element_op_); + } + return time; + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; +#endif // DEBUG_LOG + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); + if(not group_arg_valid) + { +#if DEBUG_LOG + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gemm_arg.Print(); +#endif // DEBUG_LOG + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpy(p_dev_kernel_args, + arg.gemm_kernel_args_.data(), + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + void SetWorkSpacePointer( + BaseArgument* p_arg, + void* p_workspace, + [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + p_arg_->UpdateEPointers(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) 
const override + { + return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index abee2fea53..a33d7d8fb5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,13 +26,19 @@ namespace device { template + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -64,10 +70,16 @@ __global__ void GridwiseGemm::template Run( gemm_desc_ptr[group_id].karg_, static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_); + gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); #else ignore = gemm_descs_const; ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - + using PassThrough = ck::tensor_operation::element_wise::PassThrough; struct GemmTransKernelArg { KernelArgument karg_; @@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// assumption: every D matrix has the same layout and the same datatype +template +struct ReferenceGemmMultipleD : public device::BaseOperator +{ + using DDataType = remove_cvref_t>; + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + ds_m_n_{ds_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const std::array, DsDataType::Size()>& ds_m_n_; + Tensor& c_m_n_; 
+ + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultipleD::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + ComputeTypeA v_a = 0; + ComputeTypeB v_b = 0; + + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + CDataType v_c = 0; + + if constexpr(DsDataType::Size() == 0) + { + arg.cde_element_op_(v_c, v_acc); + } + else if constexpr(DsDataType::Size() == 1) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n)); + } + else if constexpr(DsDataType::Size() == 2) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n)); + } + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_m_k, b_k_n, ds_m_n, c_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultipleD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index 056e906c27..d06a579811 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
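+
+// Editor's note: a hedged sketch of reaching the two-stage instances declared below
+// through the instance factory. The exact DeviceGroupedGemmMultipleDSplitK template
+// argument order is abbreviated here (layouts, data types, then elementwise ops), so
+// treat this as schematic rather than a drop-in snippet:
+//
+//     using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<
+//         Row, Row, ck::Tuple<>, Row,      // A/B/Ds/E layouts
+//         F16, F16, ck::Tuple<>, F16,      // A/B/Ds/E data types
+//         PassThrough, PassThrough, PassThrough>;
+//     const auto op_ptrs = ck::tensor_operation::device::instance::
+//         DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+//     // Each op_ptr can then be filtered with IsSupportedArgument() and timed.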
#pragma once @@ -146,6 +146,32 @@ void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); + template > op_ptrs; +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v) { @@ -190,6 +217,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -210,8 +239,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -228,6 +259,19 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + } +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 2625e6cbe8..5a50eca107 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -10,4 +10,6 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..8d3baf19ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
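+
+// Editor's note: each row in the tuples below is one tuning configuration, read
+// against the column legend directly above it. For example, the single "generic"
+// instance encodes BlockSize=256, MPerBlock=64, NPerBlock=128, KPerBlock=32,
+// AK1=BK1=8, MPerXDL=NPerXDL=32, with a scalar-per-vector of 1 on the A, B, and C
+// transfers, so it remains valid for unaligned problem shapes; the main list trades
+// that generality for vector widths of 8 (A/B) and 4 (C) on aligned shapes.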
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| 
CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 
S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + 
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1> + // clang-format on + >; + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d384842343 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
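+
+// Editor's note: a minimal sketch of consuming this translation unit directly rather
+// than through the factory; the function is declared in grouped_gemm.hpp above, and
+// DeviceOp stands for the matching DeviceGroupedGemmMultipleDSplitK specialization:
+//
+//     std::vector<std::unique_ptr<DeviceOp>> instances;
+//     ck::tensor_operation::device::instance::
+//         add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
+//             instances);
+//     for(const auto& op : instances)
+//         std::cout << op->GetTypeString() << std::endl; // list available configs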
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + 
+    //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+    //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
+    DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>
+    // clang-format on
+    >;
+
+void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultipleDSplitK<Row,
+                                                                 Row,
+                                                                 Empty_Tuple,
+                                                                 Row,
+                                                                 F16,
+                                                                 F16,
+                                                                 Empty_Tuple,
+                                                                 F16,
+                                                                 PassThrough,
+                                                                 PassThrough,
+                                                                 PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances, device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances{});
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
new file mode 100644
index 0000000000..41dcabbfcf
--- /dev/null
+++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
@@ -0,0 +1,366 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
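+//
+// Minimal usage sketch for this header (illustrative only; the f16 row-major
+// instantiation and the group sizes below are assumptions for the example,
+// not part of this patch):
+//
+//   using Row = ck::tensor_layout::gemm::RowMajor;
+//   std::vector<int> Ms{256, 512}, Ns{128, 128}, Ks{64, 64};
+//   std::vector<int> StrideAs{64, 64}, StrideBs{128, 128}, StrideCs{128, 128};
+//   bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
+//                                                                 ck::half_t,
+//                                                                 ck::half_t,
+//                                                                 Row,
+//                                                                 Row,
+//                                                                 Row>(
+//       1 /*do_verification*/, 1 /*init_method*/, false /*do_log*/,
+//       true /*time_kernel*/, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs,
+//       1 /*kbatch; a value <= 0 sweeps a built-in list of kbatch values*/);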
+
+#pragma once
+
+#include <iomanip>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+namespace ck {
+namespace profiler {
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+bool profile_grouped_gemm_two_stage_impl(int do_verification,
+                                         int init_method,
+                                         bool do_log,
+                                         bool time_kernel,
+                                         const std::vector<int>& Ms,
+                                         const std::vector<int>& Ns,
+                                         const std::vector<int>& Ks,
+                                         const std::vector<int>& StrideAs,
+                                         const std::vector<int>& StrideBs,
+                                         const std::vector<int>& StrideCs,
+                                         int kbatch = 1,
+                                         int n_warmup = 1,
+                                         int n_iter = 10)
+{
+    bool pass = true;
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    std::size_t group_count = Ms.size();
+
+    if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
+         group_count == StrideBs.size() && group_count == StrideCs.size()))
+    {
+        throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
+    }
+
+    std::vector<Tensor<ADataType>> a_m_k;
+    std::vector<Tensor<BDataType>> b_k_n;
+    std::vector<Tensor<CDataType>> c_m_n_host_results;
+    std::vector<Tensor<CDataType>> c_m_n_device_results;
+
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        a_m_k.push_back(
+            Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
+        b_k_n.push_back(
+            Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
+
+        c_m_n_device_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+
+        c_m_n_host_results.push_back(
+            Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
+#if DEBUG_LOG
+        std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
+                  << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
+                  << "]:" << c_m_n_device_results[i].mDesc << std::endl;
+#endif // DEBUG_LOG
+        std::size_t num_thread = 1;
+        switch(init_method)
+        {
+        case 0: break;
+        case 1:
+            a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
+            b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
+            break;
+        default:
+            a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
+            b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
+        }
+    }
+
+    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+    using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto c_element_op = CElementOp{};
+
+    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
+    std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
+
+    a_device_buf.reserve(group_count);
+    b_device_buf.reserve(group_count);
+    c_device_buf.reserve(group_count);
+
+    std::vector<const void*> p_a, p_b;
+    std::vector<void*> p_c;
+
+    p_a.reserve(group_count);
+    p_b.reserve(group_count);
+    p_c.reserve(group_count);
+
+    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
+
+    gemm_descs.reserve(group_count);
+
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        a_device_buf.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
+        b_device_buf.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
+        c_device_buf.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
+
+        a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
+        b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
+
+        gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
+
+        p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
+        p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
+        p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
+    }
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
+                                                                     BLayout,
+                                                                     ck::Tuple<>,
+                                                                     CLayout,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     ck::Tuple<>,
+                                                                     CDataType,
+                                                                     AElementOp,
+                                                                     BElementOp,
+                                                                     CElementOp>;
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    if(op_ptrs.size() <= 0)
+    {
+        throw std::runtime_error("wrong! no device GEMM instance found");
+    }
+
+    std::string best_gemm_name;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+    int best_kbatch       = 0;
+
+    auto p_ds = std::vector<std::array<const void*, 0>>{};
+
+    if(do_verification)
+    {
+        for(std::size_t i = 0; i < gemm_descs.size(); i++)
+        {
+            using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                    BDataType,
+                                                                                    CDataType,
+                                                                                    float,
+                                                                                    AElementOp,
+                                                                                    BElementOp,
+                                                                                    CElementOp>;
+
+            auto ref_gemm    = ReferenceGemmInstance{};
+            auto ref_invoker = ref_gemm.MakeInvoker();
+
+            auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
+                                                      b_k_n[i],
+                                                      c_m_n_host_results[i],
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      c_element_op);
+
+            ref_invoker.Run(ref_argument);
+        }
+    }
+
+    // profile device GEMM instances
+    for(auto& gemm_ptr : op_ptrs)
+    {
+        auto argument_ptr =
+            gemm_ptr->MakeArgumentPointer(p_a,
+                                          p_b,
+                                          p_ds,
+                                          p_c,
+                                          gemm_descs,
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{});
+
+        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
+
+        DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
+        gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
+
+        std::string gemm_name = gemm_ptr->GetTypeString();
+
+        using DeviceOpSplitK =
+            ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout,
+                                                                           BLayout,
+                                                                           ck::Tuple<>,
+                                                                           CLayout,
+                                                                           ADataType,
+                                                                           BDataType,
+                                                                           ck::Tuple<>,
+                                                                           CDataType,
+                                                                           AElementOp,
+                                                                           BElementOp,
+                                                                           CElementOp>;
+
+        // skip non-splitk grouped_gemm
+        if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
+        {
+            continue;
+        }
+
+        std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
+
+        if(kbatch > 0)
+        {
+            kbatch_list = {kbatch};
+        }
+
+        for(std::size_t j = 0; j < kbatch_list.size(); j++)
+        {
+
+            auto kbatch_curr = kbatch_list[j];
+            dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
+                ->SetKBatchSize(argument_ptr.get(), kbatch_curr);
+
+            DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
+                                           ->GetDeviceKernelArgSize(argument_ptr.get()));
+            dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
+                ->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
+
+            if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
+            {
+                gemm_desc_workspace.SetZero();
+                for(std::size_t i = 0; i < gemm_descs.size(); i++)
+                    c_device_buf[i]->SetZero();
+
+                invoker_ptr->Run(argument_ptr.get(),
+                                 StreamConfig{nullptr, false, 0, n_warmup, n_iter});
+                if(do_verification)
+                {
+                    bool instance_pass = true;
+                    for(std::size_t i = 0; i < gemm_descs.size(); i++)
+                    {
+                        c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
+                        if(std::is_same_v<CDataType, ck::bhalf_t> && kbatch_curr > 1)
+                        {
+                            instance_pass =
+                                instance_pass && ck::utils::check_err(c_m_n_device_results[i],
+                                                                      c_m_n_host_results[i],
+                                                                      "Error: Incorrect results!",
+                                                                      0.06);
+                        }
+                        else
+                        {
+                            instance_pass =
+                                instance_pass && ck::utils::check_err(c_m_n_device_results[i],
+                                                                      c_m_n_host_results[i]);
+                        }
+
+                        if(do_log)
+                        {
+                            LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
+                                << std::endl;
+                            LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
+                                << std::endl;
+                            LogRangeAsType<float>(
+                                std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
+                                << std::endl;
+                            LogRangeAsType<float>(
+                                std::cout << "c_host  : ", c_m_n_host_results[i].mData, ",")
+                                << std::endl;
+                        }
+                    }
+
+                    std::cout << "Instance: " << gemm_name << " verification "
"SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index cb6ffbec6c..e8992070b5 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) diff --git a/profiler/src/profile_grouped_gemm_two_stage.cpp b/profiler/src/profile_grouped_gemm_two_stage.cpp new file mode 100644 index 0000000000..17daf1e80c --- /dev/null +++ b/profiler/src/profile_grouped_gemm_two_stage.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+};
+
+enum struct GemmDataType
+{
+    F16_F16_F16,   // 0
+    BF16_INT8_BF16 // 1
+};
+
+#define OP_NAME "grouped_gemm_two_stage"
+#define OP_DESC "Grouped GEMM TwoStage"
+
+namespace {
+
+std::vector<int> argToIntArray(char* input)
+{
+    std::vector<int> out;
+
+    std::istringstream in(input);
+
+    std::string item;
+
+    while(std::getline(in, item, ','))
+    {
+        out.push_back(std::stoi(item));
+    }
+
+    return out;
+}
+
+int profile_grouped_gemm_two_stage(int argc, char* argv[])
+{
+    if(argc < 14)
+    {
+        std::cout
+            << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
+            << "arg2: data type (0: fp16; 1: bf16@int8)\n"
+            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n])\n"
+            << "arg4: verification (0: no; 1: yes)\n"
+            << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+            << "arg6: print tensor value (0: no; 1: yes)\n"
+            << "arg7: time kernel (0: no; 1: yes)\n"
+            << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
+               "64,64 64,64 128,128)\n"
+            << "arg14: kbatch value (default 1)\n"
+            << "optional:\n"
+            << "arg15: number of warm-up cycles (default 1)\n"
+            << "arg16: number of iterations (default 10)\n"
+            << std::endl;
+
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const auto Ms = argToIntArray(argv[8]);
+    const auto Ns = argToIntArray(argv[9]);
+    const auto Ks = argToIntArray(argv[10]);
+
+    auto StrideAs    = argToIntArray(argv[11]);
+    auto StrideBs    = argToIntArray(argv[12]);
+    auto StrideCs    = argToIntArray(argv[13]);
+    const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
+
+    const int DefaultStrideA = Ks[0];
+    const int DefaultStrideB = Ns[0];
+    const int DefaultStrideC = Ns[0];
+
+    for(size_t i = 0; i < Ms.size(); ++i)
+    {
+        StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
+        StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
+        StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
+    }
+
+    int n_warmup = 1;
+    int n_iter   = 10;
+    if(argc >= 17)
+    {
+        n_warmup = std::stoi(argv[15]);
+        n_iter   = std::stoi(argv[16]);
+    }
+
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
+                                                          ck::half_t,
+                                                          ck::half_t,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            kbatch,
+            n_warmup,
+            n_iter);
+    }
+    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
+                                                          int8_t,
+                                                          ck::bhalf_t,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            kbatch,
+            n_warmup,
+            n_iter);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
+    }
+    return 0;
+}
+
+} // anonymous namespace
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage);
From 7e5c81fed2737312f960cd41fe9afbc02669ce27 Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Thu, 4 Apr 2024 11:33:29 -0700
Subject: [PATCH 3/3] fix the latest errors with staging compiler (#1229)

---
 test/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 720ab468ea..bbb75c49e8 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -140,6 +140,7 @@ function(add_gtest_executable TEST_NAME)
     set(result ${result} PARENT_SCOPE)
 endfunction()
 
+add_compile_options(-Wno-c++20-extensions)
 add_subdirectory(magic_number_division)
 add_subdirectory(space_filling_curve)
 add_subdirectory(conv_util)