From 5ccd4c679f491fb641da5cf31c2abe63ac5d4701 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Fri, 25 Jul 2025 10:34:31 +0200 Subject: [PATCH] Add v3 support for Groupd fwd conv+bias+clamp & ckProfiler (#2463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add logging to IsSupported. * Less casting in AddClamp * Conv+bias+clamp instances & profiler BF16 * Fix 3D instances & run just 1x for verification. * :Run just once for verification conv fwd. * ckProfiler conv fwd clampwq * Remove exec bit & formatting * Add support for MultiD for grouped conv fwd v3. * Enable 2Lds. * clean * align instances * align instances * profiler fixes * Fixes * fix * fix --------- Co-authored-by: Adam Osewski Co-authored-by: Bartłomiej Kocot [ROCm/composable_kernel commit: c8eb2f995cea2d8dbbe2e286ff0fe99f75efb227] --- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 673 ++++++++++++------ .../element/binary_element_wise_operation.hpp | 8 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 327 ++++++--- profiler/src/CMakeLists.txt | 6 + .../profile_grouped_conv_fwd_bias_clamp.cpp | 191 +++++ .../src/profile_grouped_conv_fwd_clamp.cpp | 194 +++++ 6 files changed, 1098 insertions(+), 301 deletions(-) create mode 100644 profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp create mode 100644 profiler/src/profile_grouped_conv_fwd_clamp.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 48424c16b9..e30caf3aac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -21,7 +21,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -61,10 +61,11 @@ namespace { * */ template {}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -101,29 +110,41 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = 
Block2CTileMap{karg.M, karg.N, 4}; + + GridwiseGemm::template Run( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } template {}( + [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -163,22 +193,33 @@ __global__ void __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } @@ -277,10 +318,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; static constexpr bool isMultiD = DsDataType::Size() > 0; - static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD; static constexpr bool DoElementwiseBeforeCShuffle = - !isMultiABD && is_same_v && + !isMultiD && is_same_v && !is_same_v; static constexpr index_t NumATensor = GetNumABTensors(); @@ -294,12 +335,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + // Generate vector size for C & Ds + using CDEBlockTransferScalarPerVectors = + typename uniform_sequence_gen::type; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; + static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -396,30 +444,81 
@@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return out_gemmm_gemmn_desc; } + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + // desc for problem definition constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; - -#define GridwiseGemmV3TemplateParams \ - tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ - tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ - EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ - MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle + using DsGridDesc_M_N = + remove_cvref_t; // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + DsLayout, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeDataType, + BComputeDataType, + ADataType, + BDataType, + DoElementwiseBeforeCShuffle>; + + // #undef GridwiseGemmV3TemplateParams using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -493,37 +592,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 I0, I1>; - static auto - 
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) - { - const index_t M = e_grid_desc_m_n.GetLength(I0); - const index_t N = e_grid_desc_m_n.GetLength(I1); - return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); - } - // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // Argument struct Argument : public BaseArgument { Argument(const void* p_as, const void* p_bs, - const std::array&, + const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>&, - const std::array, NumDTensor>&, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, @@ -535,6 +624,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const CDEElementwiseOperation& cde_element_op) : p_a_grid_{}, p_b_grid_{}, + p_ds_grid_{p_ds}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( @@ -542,6 +632,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, @@ -561,13 +653,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 input_left_pads_, input_right_pads_}, conv_N_per_block_{conv_to_gemm_transformer_.N_}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, b_grid_desc_bk0_n_bk1_{ MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, - e_grid_desc_m_n_{ - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -583,12 +675,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_a_grid_ = static_cast(p_as); p_b_grid_ = static_cast(p_bs); + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + 
DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -610,14 +723,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); - elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ - b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; e_out_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } @@ -680,6 +793,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; } @@ -687,6 +802,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // pointers (tuple if multi AB, pointer if no) const ADataType* p_a_grid_; const BDataType* p_b_grid_; + const std::array p_ds_grid_; EDataType* p_e_grid_; // for checking IsSupportedArgument() @@ -694,6 +810,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::array a_g_n_c_wis_strides_; std::array b_g_k_c_xs_lengths_; std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; std::array e_g_n_k_wos_lengths_; std::array e_g_n_k_wos_strides_; std::array conv_filter_strides_; @@ -705,18 +823,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 index_t num_group_; ConvToGemmFwdTransformer conv_to_gemm_transformer_; - index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - EGridDesc_M_N e_grid_desc_m_n_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffset compute_ptr_offset_of_groups_; + ComputePtrOffset compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -759,6 +877,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; index_t gdx, gdy, gdz; + // TODO: Do we want to support kbatch ?? 
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); @@ -784,20 +903,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 sizeof(EDataType); } - typename GridwiseGemm::Argument gemm_arg{p_a_grid, - p_b_grid, - p_e_grid, - GemmM, - GemmN, - GemmK, - I0, - I0, - I0, - I1, - false, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + GemmM, + GemmN, + GemmK, + // No need to set strides, we pass descs to kernel + I0, + I0, + {}, + I0, + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) @@ -827,24 +949,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 gemm_arg_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, arg.compute_ptr_offset_of_groups_, arg.compute_ptr_offset_of_n_); } else { - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; @@ -854,15 +977,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } // Tail number could be One to Seven @@ -870,30 +994,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } @@ -903,10 +1029,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, 
DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -921,10 +1048,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -939,10 +1067,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -957,10 +1086,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -975,10 +1105,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -993,10 +1124,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1012,10 +1144,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1026,10 +1159,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1041,48 +1175,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - 
GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } } + // has_main_k_block_loop else { // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } @@ -1095,6 +1233,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 float avg_time = 0.f; if constexpr(!isMultiABD) { + // Transpose to NGHWC layotu if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1147,6 +1286,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 avg_time += RunGemm(arg, stream_config); + // Transpose result back to NGCHW if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1205,6 +1345,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(isMultiABD) { return false; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The MultiABD is not supported!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } } // check device @@ -1213,12 +1359,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // FIXME: re-enable fp64 when SWDEV-335738 is fixed if constexpr(!(is_same_v || is_same_v)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "On gfx908 the accumulation data type must be one of fp32 or int32!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } if(!ck::is_xdl_supported()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Current device does not support xdl instructions!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1236,6 +1395,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input paramters do not align with specialization " + "Filter1x1Stride1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1252,6 +1418,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "The input paramters do not align with specialization Filter1x1Pad0!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1268,11 +1441,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[A Layout] The number of input channels is not a multiple of " + "ABlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1288,11 +1474,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[B Layout] The number of input channels is not a multiple of " + "BBlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1301,11 +1500,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * C is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1316,11 +1529,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The input_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The output_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1340,6 +1567,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] One of the transposed vectors is exceeding 2GB " + "memory size!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1354,17 +1588,37 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[E Layout] The K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } // Gridwise gemm v3 doesn't verify descriptors size if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1374,8 +1628,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + typename GridwiseGemm::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; return GridwiseGemm::CheckValidity(gemm_arg); } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 34c76b89e4..d86f01e255 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -379,10 +379,10 @@ struct AddClamp __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const { - const half_t a = x0 + x1; - y = a > type_convert(floor_) - ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) - : type_convert(floor_); + const half_t floor = type_convert(floor_); + const half_t ceil = type_convert(ceil_); + const half_t a = x0 + x1; + y = a > floor ? (a < ceil ? a : ceil) : floor; }; template <> diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index c8dbd81b73..a3694e3767 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -143,7 +143,8 @@ template + typename LDSTypeB = BDataType, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -466,6 +467,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); } + else + { + static_assert(false, + "The layout configuration is not supported! 
" + "Only support Row & Col major."); + } }(); // pad M and N @@ -538,8 +545,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Number{}); } - using DsGridDesc_M_N = remove_cvref_t; - struct Problem { __host__ __device__ Problem() = default; @@ -1245,11 +1250,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1273,11 +1278,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1288,17 +1293,62 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1515,43 +1565,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( 
make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -1601,7 +1671,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -1625,7 +1697,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = @@ -1698,12 +1770,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1729,12 +1801,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* 
p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1745,8 +1817,53 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + Run_2Lds(p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1982,43 +2099,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + 
vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -2068,7 +2205,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -2092,7 +2231,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index e27fda05e4..4700a34e9d 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -94,6 +94,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() @@ -197,6 +199,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp new file mode 100644 index 0000000000..34b3df1c65 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bias_clamp" +#define OP_DESC "Grouped Convolution Forward+Bias+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto 
in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_clamp); diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp new file mode 100644 index 0000000000..600f91744a --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_clamp" +#define OP_DESC "Grouped Convolution Forward+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto 
wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = + ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_clamp);
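
Note on the fused epilogue: the AddClamp change above only converts floor_ and ceil_ to half_t once per call instead of inside every comparison ("Less casting in AddClamp"). Semantically the operator stays y = clamp(x0 + x1, floor_, ceil_), and it is what the new conv + bias + clamp instances plug in as the CDEElementwiseOperation. A minimal host-side sketch of that behaviour, assuming a float-only stand-in named AddClampRef with illustrative bounds (the real functor in binary_element_wise_operation.hpp has half_t/bhalf_t/float overloads):

// add_clamp_sketch.cpp -- illustrative reference sketch only
#include <cassert>

struct AddClampRef
{
    float floor_;
    float ceil_;

    // y = clamp(x0 + x1, floor_, ceil_), written with the same ternary as the device functor
    void operator()(float& y, const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
    }
};

int main()
{
    const AddClampRef op{0.f, 6.f}; // illustrative bounds
    float y = 0.f;

    op(y, 5.5f, 1.0f); // 6.5 -> clamped to ceil_
    assert(y == 6.0f);

    op(y, -2.0f, 0.5f); // -1.5 -> clamped to floor_
    assert(y == 0.0f);

    return 0;
}

In the multi-D gridwise GEMM this op runs on the LDS-to-global copy whenever D tensors (the bias) are present; the DoElementwiseBeforeCShuffle path, which moves the op to the VGPR-to-LDS copy, is only taken when there are no D tensors and the data-type conditions allow it.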
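
Both new profiler entry points follow the same command-line convention as the existing grouped_conv_fwd op: op name, data type, layout, indexing type, verification, initialization, tensor-print and time-kernel flags, the number of spatial dimensions, and finally the convolution sizes consumed by parse_conv_param. Assuming the usual parameter order printed by get_conv_param_parser_helper_msg (G, N, K, C, filter spatial lengths, input spatial lengths, conv strides, dilations, left pads, right pads), an illustrative 2-D fp16 NHWGC run of the bias + clamp profiler would be:

    ckProfiler grouped_conv_fwd_bias_clamp 1 1 0 1 1 0 1 2  32 4 192 192  3 3  71 71  2 2  1 1  1 1  1 1

grouped_conv_fwd_clamp takes the identical argument list; the two entry points differ only in the profiler implementation they dispatch to (with and without the bias D tensor).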