From 391e06e070ebb4e4b56ec37b581c7d0baea9fa7a Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Sun, 25 Jan 2026 19:53:52 +0000 Subject: [PATCH] tmp save between remotes --- CMakeLists.txt | 2 +- include/ck/library/utility/check_err.hpp | 16 +- include/ck/stream_config.hpp | 4 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 11 + ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 2 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 2 + ...nv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp | 1983 +++++++++++++++++ .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 29 +- ..._grouped_conv_bwd_data_xdl_v3_instance.hpp | 4 +- .../gpu/grouped_convolution_backward_data.hpp | 298 +-- .../grouped_conv2d_bwd_data/CMakeLists.txt | 1 + .../profile_grouped_conv_bwd_data_impl.hpp | 3 + profiler/src/CMakeLists.txt | 326 +-- 13 files changed, 2347 insertions(+), 334 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 54464d6809..f7d8610801 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -741,7 +741,7 @@ if (NOT MIOPEN_REQ_LIBS_ONLY) endif() if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) - add_subdirectory(codegen) + #add_subdirectory(codegen) endif() #Create an interface target for the include only files and call it "composablekernels" diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 677361b579..32d5fc7026 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -190,7 +190,7 @@ check_err(const Range& out, if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? 
err : max_err; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -244,7 +244,7 @@ check_err(const Range& out, if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -298,7 +298,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -350,7 +350,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -409,7 +409,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -462,7 +462,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -511,7 +511,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -560,7 +560,7 @@ check_err(const Range& out, { max_err = err > max_err ? 
err : max_err; err_count++; - if(err_count < 5) + if(err_count < 40) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index ea1c15b1aa..ce89b7c9e3 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -10,8 +10,8 @@ struct StreamConfig hipStream_t stream_id_ = nullptr; bool time_kernel_ = false; int log_level_ = 0; - int cold_niters_ = 5; - int nrepeat_ = 50; + int cold_niters_ = 0; + int nrepeat_ = 1; bool flush_cache = false; int rotating_count = 1; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 5604a31091..b6b50a3364 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -187,6 +187,7 @@ struct BlockwiseGemmXdlops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -212,6 +213,7 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -313,6 +316,14 @@ struct BlockwiseGemmXdlops_pipeline_v1()(ik) = b_thread_buf[Number{}]; + + // if(threadIdx.x == 0) { + // printf("a: %f b: %f\n", + // static_cast(a_thread_buf[Number{}]), + // static_cast(b_thread_buf[Number{}])); + // } }); using mfma_input_type = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index bbf62d5fbe..0f16b79221 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}, 1, 1)); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemmCTranspose::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}, 1, 1)); using Block2ETileMap = typename GridwiseGemmCTranspose::Block2CTileMap; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index b324845c3e..aa557590da 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1186,8 +1186,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto clear_workspace = [&]() { if(arg.bwd_needs_zero_out && gemm_set_id == 0) { + printf("pre memset\n"); hip_check_error(hipMemsetAsync( p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + printf("post memset\n"); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..09aa04e7e0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1983 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const std::array gemm_kernel_args, + const index_t gemms_count, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) +{ +#if defined(__gfx9__) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = 
__builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); + + const long_index_t a_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + [[maybe_unused]] const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + e_n_offset + ds_batch_offset[i]; + }); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + [[maybe_unused]] const auto num_k_per_block = + 
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_.GetLength(Number<0>{}) / KBatch; + + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } else { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + p_ds_grid_grp, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + gemm_kernel_args[group_id].block_2_ctile_map_, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_m_n_, + gemm_kernel_args[group_id].e_grid_desc_m_n_, + KBatch, + k_idx); + } +#else + ignore = karg; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; + ignore = KBatch; + +#endif // End of if defined(__gfx9__) +} +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, C, Hi, Wi] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions.
+ static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + static_assert(std::is_same_v, "A not NGHWC"); + static_assert(std::is_same_v, "B not GKYXC"); + static_assert(std::is_same_v, "C not NGHWK"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // Note: the values in CShuffleBlockTransferScalarPerVector sequence must be all the same. + // This is a limitation of the thread transfer implementation (v7r3) + // It should be fixed later on + static constexpr index_t CShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector{}[I0]; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = + (CShuffleBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && + std::is_same_v, element_wise::PassThrough>; + + // TODO: Add support for different A and B data types. 
+ using ABDataType = ADataType; + + static constexpr bool isATensorColMajor = + (ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = false; + // (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + // is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = false; + // (NeedTransposeKernel == false) && (is_same_v || + // is_same_v); + + using ALayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGK, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGK, + ALayout>>; + using BLayoutAfterTranspose = std::conditional_t< + is_NGCHW_GKCYX_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::GKYXC, + std::conditional_t() && + NeedTransposeKernel, + tensor_layout::convolution::GKZYXC, + BLayout>>; + using ELayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGC, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGC, + ELayout>>; + + static_assert(std::is_same_v, "Aafter not NGHWC"); + static_assert(std::is_same_v, "Bafter not GKYXC"); + static_assert(std::is_same_v, "Cafter not NGHWK"); + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + // Dummy function just used to create an alias to Grid Descriptors + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + 
TransformConvBwdDataToGemm_v1; + return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); + }, + Number{}); + + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); + + if constexpr(CTranspose) + { + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + else + { + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + } + + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + DirectLoad ? 
BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeType, + BComputeType, + ADataType, + BDataType, + false, + DirectLoad>; + +// #define GridwiseGemmCTransposeTemplateParameters \ +// ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ +// CShuffleDataType, DsDataType, EDataType, BElementwiseOp, AElementwiseOp, CDEElementwiseOp, \ +// GemmSpec, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerWmma, MPerWmma, \ +// NRepeat, MRepeat, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ +// BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ +// BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ +// BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ +// ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ +// ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ +// ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ +// ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ +// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ +// CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ +// AComputeType, false, false + + using GridwiseGemmCTranspose = GridwiseGemm; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // Note: the dummy function is used just to create the alias + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + using Block2ETileMap = typename GridwiseGemmCTranspose::Block2CTileMapDefault; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + DsGridDesc_M_N + ds_grid_desc_m_n, + EGridDesc_M_N + e_grid_desc_m_n, + GroupedGemmBlock2ETileMap block_2_ctile_map, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + + ds_grid_desc_m_n_( + ds_grid_desc_m_n), + + e_grid_desc_m_n_( + e_grid_desc_m_n), + block_2_ctile_map_(block_2_ctile_map), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_M_N + ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + // block-to-e-tile map for elementwise kernels + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + static 
constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr index_t TransposeTransferInScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector); + static constexpr index_t TransposeTransferOutScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector); + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + 
Sequence, + Sequence, + I0, + I1>; + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + const void* out_mem_void = static_cast(p_e); + static_for<0, NumDTensor, 1>{}([&](auto i) { + 
if(p_ds[i] == out_mem_void) + { + if_d_is_output_mem = true; + } + }); + + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround until the above conditions are proven/fixed. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + std::array a_g_n_k_wos_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides) + : a_g_n_k_wos_strides; + std::array b_g_k_c_xs_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides; + std::array e_g_n_c_wis_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides) + : e_g_n_c_wis_strides; + + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ?
Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? 
math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + else + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + else + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + }(); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + ds_grid_desc_m_n(i) = 
conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + const index_t grid_size_grp = + std::get<0>(GridwiseGemmCTranspose::CalculateGridSize( + e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1), 1)); + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + const auto block_2_etile_map = GroupedGemmBlock2ETileMap( + Block2ETileMap( + e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1), 4), + BlockStart); + + // const index_t GemmM = a_grid_desc_m_k.GetLength(I0); + // const index_t GemmN = b_grid_desc_n_k.GetLength(I0); + const index_t GemmK = a_grid_desc_m_k.GetLength(I1); + + //onst auto MBlock = GridwiseGemmCTranspose::CalculateMBlock(GemmM); + //onst auto NBlock = GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + index_t k_grain = split_k * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + + const bool HasMainKBlockLoop = + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(K_split); + + gemm_kernel_args_[gemms_count_ / + MaxGroupedGemmGroupsNum][gemms_count_ % + MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + 
math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = + e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + + if constexpr(NeedTransposeKernel) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapInOutElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapWeiElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + 
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapInOutElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + + compute_ptr_offset_of_workspace_n_.BatchStrideA_ = + a_g_n_k_wos_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_workspace_n_.BatchStrideE_ = + e_g_n_c_wis_strides[1] * conv_N_per_block_; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_m_n_container_[i][j] << std::endl; + }); + + std::cout 
<< "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_m_n_container_[i] << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + // std::vector + // ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + // std::vector + // e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map elementwise kernels + Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_e_; + Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_workspace_n_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + 
std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + if constexpr(NeedTransposeKernel) + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } + } + // Create dummy Ds strides because they are not used in convolution + // since we pass the grid descriptor to gridwise gemm + std::array StrideDs_dummy; + static_for<0, NumDTensor, 1>{}([&](auto i) { StrideDs_dummy[i] = I0; }); + // TODO: fix this, it's not nice to go back and forth + std::array p_ds; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds[i] = static_cast(arg.p_ds_grid_[i]); }); + + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); + typename GridwiseGemmCTranspose::Argument gemm_arg{ + p_a_grid, + p_b_grid, + p_ds, + p_e_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + StrideDs_dummy, + I0, + arg.k_batch_, + CTranspose ? 
arg.b_element_op_ : arg.a_element_op_, + CTranspose ? arg.a_element_op_ : arg.b_element_op_, + arg.cde_element_op_}; + if(!GridwiseGemmCTranspose::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! device_op has invalid setting"); + } + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + // printf("zeroying workspace\n"); + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop_.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemmCTranspose, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, + MaxGroupedGemmGroupsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + gemm_kernel_args, + gemms_count_for_set, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + }; + if(has_loop_in_all_gemm) + { + ave_time += 
launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + // Transpose from NGKHW to NHWGK + if constexpr(NeedTransposeKernel) + { + EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync(p_e_in_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + }; + + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_) * + arg.num_workgroups_per_Conv_N_; + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? 
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = + kernel_elementwise_batched_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + I1, + I1, + I1, + I1>; + + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size, + arg.num_workgroups_per_Conv_N_, + I1, // B is not splited per N + std::array{ + static_cast(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)}, + std::array{0}, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, + std::array{0}); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + arg.Print(); + + // Transpose from NHWGC to NGCHW + if constexpr(NeedTransposeKernel) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_) * + arg.num_workgroups_per_Conv_N_; + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + 
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + I1, + I1>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}, + arg.num_workgroups_per_Conv_N_, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideE_)}, + std::array{static_cast( + arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)}); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + // { + // if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + // { + // std::cout << "This configuration is not supported!" << " In " << __FILE__ << ":" + // << __LINE__ << ", in function: " << __func__ << std::endl; + // } + // return false; + // } + + // check device + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ConvBwdDataSpecialization is unsupported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || NeedTransposeKernel) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else if(is_same_v || + is_same_v) + { + static_assert(NeedTransposeKernel == false); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "CTranspose is not supported!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "CTranspose is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + } + + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector load D matrix from global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + } + else + { + if(input_spatial_acum % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "input_spatial_acum / " + "CShuffleBlockTransferScalarPerVector_NPerBlock is wrong!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + } + } + else + { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ds_valid is false!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + ds_valid = false; + } + }); + + if(!ds_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ds_valid is false!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector store C matrix into global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(input_spatial_acum % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "input_spatial_acum / " + "ChuffleBlockTransferScalarPerVector_NPerBlock is wrong!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" 
<< " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + if constexpr(NeedTransposeKernel) + { + if((ConvG * ConvC) % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + if((ConvG * ConvK) % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t a_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t e_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "a_spatial_acum % TransposeTransferInScalarPerVectorAligned is wrong!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + + if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "e_spatial_acum % TransposeTransferOutScalarPerVectorAligned is wrong!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Warning: Workspace for " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument is not " + "allocated, use SetWorkSpacePointer." 
+ << std::endl; + } + return false; + } + } + + // Check gridwise gemm validity + // Create dummy values for Ds pointers and strides + std::array p_ds_grid_dummy; + std::array StrideDs_dummy; + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_dummy[i] = nullptr; + StrideDs_dummy[i] = I0; + }); + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + // Create gemm arguments with dummy values to check for validity + typename GridwiseGemmCTranspose::Argument gemm_arg{ + nullptr, // p_as_grid + nullptr, // p_bs_grid + p_ds_grid_dummy, // p_ds_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + I0, // StrideAs + I0, // StrideBs + StrideDs_dummy, // StrideDs + I0, // StrideE + arg.k_batch_, + AElementwiseOp{}, + BElementwiseOp{}, + CDEElementwiseOp{}}; + + if(!GridwiseGemmCTranspose::CheckValidity(gemm_arg)) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const 
std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + 
b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3" + << (DirectLoad ? "_DirectLoad" : "") + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXdl << ", " + << NPerXdl << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferInScalarPerVectorAligned: " + << TransposeTransferInScalarPerVectorAligned <<", " + << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned; + } + + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index e7e24b148a..d1015ee504 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -612,7 +612,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } template - __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + __host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) { return generate_tuple( @@ -1402,7 +1402,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } template - __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + __device__ __host__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( @@ -1509,7 +1509,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const DsGridDesc_M_N& ds_grid_desc_m_n, - const CGridDesc_M_N& c_grid_desc_m_n) + const CGridDesc_M_N& c_grid_desc_m_n, + const index_t k_batch = 1, + const index_t k_idx = 0) { const auto a_grid_buf = make_dynamic_buffer( @@ -1538,6 +1540,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + //const index_t n_block_data_idx_on_grid =__builtin_amdgcn_readfirstlane(k_id * KPerBlock); + + const index_t num_ak0_per_block = + __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); + const index_t num_bk0_per_block 
= + __builtin_amdgcn_readfirstlane(b_grid_desc_bk0_n_bk1.GetLength(I0) / k_batch); + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); @@ -1571,7 +1580,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 2, ABlockTransferSrcScalarPerVector>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0), a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0)); } @@ -1601,7 +1610,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(num_ak0_per_block * k_idx, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1627,7 +1636,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 2, BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0), b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0)); } @@ -1657,7 +1666,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1691,7 +1700,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + (KPerBlock * k_batch)); + + // if(threadIdx.x == 0) { + // printf("num_k block main loop: %d\n m_block_data_idx_on_grid: %d\n n_block_data_idx_on_grid: %d\n", num_k_block_main_loop, m_block_data_idx_on_grid, n_block_data_idx_on_grid); + // } 
blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp index 2dd918461e..dbbebcd7e0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp @@ -5,7 +5,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -65,7 +65,7 @@ using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple< // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>> + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, S<2,2,2>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 1, 1, true> // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, // DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 208793b415..449c376951 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -79,7 +79,7 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -95,8 +95,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + // op_ptrs); } #endif } @@ -109,11 +109,11 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - 
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances( + // op_ptrs); } #endif if constexpr(is_same_v && is_same_v && @@ -124,23 +124,23 @@ struct DeviceOperationInstanceFactory< #ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( + // op_ptrs); } #endif } @@ -149,12 +149,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - 
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances( + // op_ptrs); } #endif } @@ -166,7 +166,7 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -174,7 +174,7 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -182,8 +182,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( + // op_ptrs); } #endif } @@ -195,11 +195,11 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs); + // 
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -207,11 +207,11 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -219,12 +219,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( + // op_ptrs); } #endif } @@ -239,8 +239,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -248,8 +248,8 @@ struct DeviceOperationInstanceFactory< is_same_v 
&& is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -257,8 +257,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + // op_ptrs); } #endif } @@ -270,12 +270,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances( + // op_ptrs); } #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 @@ -283,8 +283,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + // op_ptrs); } #endif if constexpr(is_same_v && is_same_v && @@ -295,23 +295,23 @@ struct DeviceOperationInstanceFactory< #ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( - op_ptrs); - 
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( + // op_ptrs); } #endif } @@ -320,12 +320,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances( + // op_ptrs); } #endif } @@ -337,8 +337,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances( - 
op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -346,8 +346,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -355,8 +355,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( + // op_ptrs); } #endif } @@ -368,12 +368,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -381,12 +381,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( + // op_ptrs); + // 
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -394,12 +394,12 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( + // op_ptrs); } #endif } @@ -417,10 +417,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + // op_ptrs); } #endif @@ -429,9 +429,9 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + // op_ptrs); } #endif } @@ -443,14 +443,14 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) 
{ - add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -458,10 +458,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + // op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_v3_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -469,9 +469,9 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + // add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + // op_ptrs); } #endif } @@ -486,10 +486,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); - 
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + // op_ptrs); } #endif @@ -498,10 +498,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + // op_ptrs); } #endif } @@ -513,14 +513,14 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -528,10 +528,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( - 
op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_v3_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + // op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -539,10 +539,10 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( - op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + // op_ptrs); + // add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + // op_ptrs); } #endif } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 19e27cf173..6ec3c2279d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -32,6 +32,7 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index eceb70c05f..867604564d 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ 
b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -210,6 +210,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, // workspace_sz will be equal to 0 for other layout than NGCHW const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); DeviceMem workspace_dev(workspace_sz); + // printf("run impl\n"); op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); if(op_ptr->IsSupportedArgument(argument_ptr.get())) @@ -224,8 +225,10 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, auto invoker_ptr = op_ptr->MakeInvokerPointer(); + // printf("prerun\n"); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + // printf("post run\n"); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 012d6e1502..2b2ed70339 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -14,118 +14,118 @@ message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}") message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}") set(PROFILER_OPS - profile_gemm.cpp - profile_reduce.cpp - profile_groupnorm_bwd_data.cpp - profile_groupnorm_fwd.cpp - profile_layernorm_bwd_data.cpp - profile_layernorm_bwd_gamma_beta.cpp - profile_groupnorm_bwd_gamma_beta.cpp - profile_layernorm_fwd.cpp - profile_max_pool2d_fwd.cpp - profile_pool3d_fwd.cpp - profile_avg_pool3d_bwd.cpp - profile_max_pool3d_bwd.cpp - profile_avg_pool2d_bwd.cpp - profile_max_pool2d_bwd.cpp - profile_softmax.cpp - profile_batchnorm_fwd.cpp - profile_batchnorm_bwd.cpp - profile_batchnorm_infer.cpp - profile_conv_tensor_rearrange.cpp - profile_transpose.cpp - profile_permute_scale.cpp - profile_gemm_quantization.cpp + # profile_gemm.cpp + # profile_reduce.cpp + # profile_groupnorm_bwd_data.cpp + # profile_groupnorm_fwd.cpp + # profile_layernorm_bwd_data.cpp + # 
profile_layernorm_bwd_gamma_beta.cpp + # profile_groupnorm_bwd_gamma_beta.cpp + # profile_layernorm_fwd.cpp + # profile_max_pool2d_fwd.cpp + # profile_pool3d_fwd.cpp + # profile_avg_pool3d_bwd.cpp + # profile_max_pool3d_bwd.cpp + # profile_avg_pool2d_bwd.cpp + # profile_max_pool2d_bwd.cpp + # profile_softmax.cpp + # profile_batchnorm_fwd.cpp + # profile_batchnorm_bwd.cpp + # profile_batchnorm_infer.cpp + # profile_conv_tensor_rearrange.cpp + # profile_transpose.cpp + # profile_permute_scale.cpp + # profile_gemm_quantization.cpp ) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) - list(APPEND PROFILER_OPS profile_contraction_scale.cpp) + # list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) + # list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() if(CK_EXPERIMENTAL_BUILDER) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp) endif() endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) - list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) + # list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) + # list(APPEND PROFILER_OPS 
profile_gemm_add.cpp) + # list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) + # list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) + # list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) + # list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) + # list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) + # list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") - list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) - list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) - list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp) + # list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) + # list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) + # list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) + # list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") - list(APPEND PROFILER_OPS profile_gemm_mx.cpp) + # list(APPEND PROFILER_OPS profile_gemm_mx.cpp) endif() - list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) - list(APPEND PROFILER_OPS profile_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) - list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp) - list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp) - list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add.cpp) + # list(APPEND PROFILER_OPS 
profile_gemm_bias_add_reduce.cpp) + # list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) + # list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) + # list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp) + # list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp) + # list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) - list(APPEND PROFILER_OPS profile_conv_fwd.cpp) + # list(APPEND PROFILER_OPS profile_conv_fwd.cpp) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")) - list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp) + # list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") - list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp) + # list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") - list(APPEND PROFILER_OPS profile_gemm_universal.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) - list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) + # list(APPEND PROFILER_OPS profile_gemm_universal.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) + # list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) + # list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + # list(APPEND PROFILER_OPS
profile_grouped_conv_fwd_bias_bnorm_clamp.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) - list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) + # list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) - list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) - list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) + # list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) + # list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp) + # list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) + # list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) endif() - list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) endif() if(DL_KERNELS) - list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + # list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp) + # list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() if(CK_ENABLE_INT8) - list(APPEND 
PROFILER_OPS profile_gemm_quantization.cpp) + # list(APPEND PROFILER_OPS profile_gemm_quantization.cpp) endif() set(PROFILER_SOURCES profiler.cpp) @@ -152,131 +152,131 @@ endif() set(DEVICE_INSTANCES "") -list(APPEND DEVICE_INSTANCES device_gemm_instance) -list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance) -list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance) -list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance) -list(APPEND DEVICE_INSTANCES device_softmax_instance) -list(APPEND DEVICE_INSTANCES device_reduce_instance) -list(APPEND DEVICE_INSTANCES device_batchnorm_instance) -list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance) -list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance) -list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance) -list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance) -list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance) -list(APPEND DEVICE_INSTANCES device_image_to_column_instance) -list(APPEND DEVICE_INSTANCES device_column_to_image_instance) -list(APPEND DEVICE_INSTANCES device_transpose_instance) -list(APPEND DEVICE_INSTANCES device_permute_scale_instance) +# list(APPEND DEVICE_INSTANCES device_gemm_instance) +# list(APPEND DEVICE_INSTANCES device_normalization_fwd_instance) +# list(APPEND DEVICE_INSTANCES device_normalization_bwd_data_instance) +# list(APPEND DEVICE_INSTANCES device_normalization_bwd_gamma_beta_instance) +# list(APPEND DEVICE_INSTANCES device_softmax_instance) +# list(APPEND DEVICE_INSTANCES device_reduce_instance) +# list(APPEND DEVICE_INSTANCES device_batchnorm_instance) +# list(APPEND DEVICE_INSTANCES device_pool2d_fwd_instance) +# list(APPEND DEVICE_INSTANCES device_pool3d_fwd_instance) +# list(APPEND DEVICE_INSTANCES device_avg_pool2d_bwd_instance) +# list(APPEND DEVICE_INSTANCES device_avg_pool3d_bwd_instance) +# list(APPEND DEVICE_INSTANCES device_max_pool_bwd_instance) +# list(APPEND DEVICE_INSTANCES device_image_to_column_instance) +# 
list(APPEND DEVICE_INSTANCES device_column_to_image_instance) +# list(APPEND DEVICE_INSTANCES device_transpose_instance) +# list(APPEND DEVICE_INSTANCES device_permute_scale_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance) - list(APPEND DEVICE_INSTANCES device_contraction_scale_instance) + # list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance) + # list(APPEND DEVICE_INSTANCES device_contraction_scale_instance) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_INSTANCES device_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) - list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + # list(APPEND 
DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) endif() - list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") - list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]") - list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) - list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") - list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) endif() - list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance) - list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance) - list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance) - list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance) - list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) + # 
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance) + # list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance) + # list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance) + # list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance) list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" )) - list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") - list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") - list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) - list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) + # list(APPEND 
DEVICE_INSTANCES device_gemm_universal_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) + # 
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) - list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) + # list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) endif() - list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES 
device_grouped_convnd_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") if(CK_EXPERIMENTAL_BUILDER) - list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances) + # list(APPEND DEVICE_INSTANCES device_grouped_conv_fwd_tile_instances) endif() endif() if(DL_KERNELS) - list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) - list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance) + # list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() if(CK_ENABLE_INT8) - list(APPEND DEVICE_INSTANCES device_quantization_instance) + # list(APPEND DEVICE_INSTANCES device_quantization_instance) endif() set(PROFILER_LIBS utility getopt::getopt)