From 0688e667df35f08035704fac2d11271d149462a6 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski Date: Mon, 8 Dec 2025 16:50:58 +0000 Subject: [PATCH] refactotred int8 tuples --- ...conv_fwd_wmma_cshufflev3_comp_instance.hpp | 89 ------------------- ..._conv_fwd_wmma_cshufflev3_mem_instance.hpp | 80 ----------------- .../gpu/grouped_convolution_forward.hpp | 25 ++++-- ...volution_forward_clamp_wmma_cshufflev3.inc | 14 +++ ...nvolution_forward_comp_wmma_cshufflev3.inc | 32 ------- ...ed_convolution_forward_wmma_cshufflev3.inc | 32 +++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 16 +--- ...flev3_gnhwc_gkyxc_gnhwk_int8_instance.cpp} | 60 ++++++------- ...fflev3_nhwgc_gkyxc_nhwgk_int8_instance.in} | 28 +++--- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 32 +------ ...ev3_gndhwc_gkzyxc_gndhwk_int8_instance.cpp | 36 ++++---- ...v3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp} | 86 +++++++----------- 12 files changed, 156 insertions(+), 374 deletions(-) delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_wmma_cshufflev3.inc rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/{comp/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp => device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instance.cpp} (77%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/{mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in => device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_instance.in} (80%) rename library/src/tensor_operation_instance/gpu/{grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in => grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp} (50%) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp deleted file mode 100644 index a23b7366fa..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_comp_instance.hpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -#ifdef CK_ENABLE_FP8 -using F8 = ck::f8_t; -#endif - -#ifdef CK_ENABLE_BF8 -using BF8 = ck::bf8_t; -#endif - -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; -using I8 = int8_t; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using namespace ck::tensor_layout::convolution; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddClamp = ck::tensor_operation::element_wise::AddClamp; -using Clamp = ck::tensor_operation::element_wise::Clamp; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; - -template , - typename OutElementOp = PassThrough> -using device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances = std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> -#ifndef ONE_INSTANCE_PER_LIST - , - // "2x" instances - // DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, // Assert broken - // "part 2" instances - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp deleted file mode 100644 index 6937bf25d7..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp +++ /dev/null @@ -1,80 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -#ifdef CK_ENABLE_FP8 -using F8 = ck::f8_t; -#endif - -#ifdef CK_ENABLE_BF8 -using BF8 = ck::bf8_t; -#endif - -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; -using I8 = int8_t; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using namespace ck::tensor_layout::convolution; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddClamp = ck::tensor_operation::element_wise::AddClamp; -using Clamp = ck::tensor_operation::element_wise::Clamp; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; - -template , - typename OutElementOp = PassThrough> -using device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances = std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> -#ifndef ONE_INSTANCE_PER_LIST - , - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> -#endif - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 779757f384..ad720918c7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -29,7 +29,6 @@ #include "grouped_convolution_forward_wmma.inc" #endif #include "grouped_convolution_forward_wmma_cshufflev3.inc" -#include "grouped_convolution_forward_comp_wmma_cshufflev3.inc" #include "grouped_convolution_forward_mem_inter_wmma_cshufflev3.inc" #include "grouped_convolution_forward_mem_intra_wmma_cshufflev3.inc" #endif @@ -832,6 +831,15 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instances( + op_ptrs); + } #endif } @@ -872,12 +880,6 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instances( + op_ptrs); + } #endif } #endif // CK_USE_WMMA diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc index a364404218..8be584949f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_wmma_cshufflev3.inc @@ -86,6 +86,20 @@ void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_i PassThrough, Clamp>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); + // void // add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( // std::vector>>& instances); -#endif - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc index d059339eea..cba5e9b2cd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma_cshufflev3.inc @@ -90,6 +90,22 @@ void add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_f16_instanc PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_gnhwc_gkyxc_gnhwk_int8_instances( + std::vector>>& instances); +#endif + // grouped conv2d forward, NHWGC/GKYXC/NHWGK #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -336,6 +352,22 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f8_insta F8>>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_INT8 void add_device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector{}); - + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdDefault>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0>{}); - + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1P0>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0>{}); - + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1S1P0>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_instance.in similarity index 80% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_instance.in index 19ea7f8162..4a7d67dd1a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_instance.in @@ -2,12 +2,12 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" #include "ck/utility/filter_tuple.hpp" namespace ck::tensor_operation::device::instance { -using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances = +using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances = std::vector>>; template -void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard( - device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances) +void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances_shard( + device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_intra_instances& instances) { add_device_operation_instances( instances, ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, - ConvFwdDefault, - Interwave>, + ConvFwdDefault>, Shards, ShardIndex>{}); add_device_operation_instances( instances, ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, - ConvFwd1x1P0, - Interwave>, + ConvFwd1x1P0>, Shards, ShardIndex>{}); add_device_operation_instances( instances, ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, - ConvFwd1x1S1P0, - Interwave>, + ConvFwd1x1S1P0>, Shards, ShardIndex>{}); add_device_operation_instances( instances, ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, - ConvFwdOddC, - Interwave>, + ConvFwdOddC>, Shards, ShardIndex>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index ea6bf15eb4..c33a3c9a05 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -74,6 +74,7 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp ) # Add generated files for sharded instantiations. include(ShardInstantiation) @@ -185,37 +186,6 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/wmma ) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances - TEMPLATE_FILE wmma/mem/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in - NUM_SHARDS 16 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/wmma/mem -) - -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances - TEMPLATE_FILE wmma/mem/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in - NUM_SHARDS 16 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/wmma/mem -) - -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances - TEMPLATE_FILE wmma/comp/device_grouped_conv3d_fwd_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in - NUM_SHARDS 16 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/wmma/comp -) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_comp_instances - TEMPLATE_FILE wmma/comp/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in - NUM_SHARDS 16 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/wmma/comp -) - if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_instance.cpp index 922ff39aeb..8db4a50ad3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_instance.cpp @@ -25,28 +25,28 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_gndhwc_gkzyxc_gndhwk_int8_ins { add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - ConvFwdDefault>{}); + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwdDefault>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - ConvFwd1x1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1P0>{}); add_device_operation_instances( instances, - device_grouped_conv_fwd_wmma_cshufflev3_int8_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - ConvFwd1x1S1P0>{}); + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp similarity index 50% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp index 4438a2830f..fff285ef73 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/mem/device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_cshufflev3_ngcdhw_gkczyx_ngkdhw_int8_instance.cpp @@ -1,81 +1,55 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_mem_instance.hpp" -#include "ck/utility/filter_tuple.hpp" -namespace ck::tensor_operation::device::instance { +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { -using device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances = - std::vector>>; - -template -void add_device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard( - device_grouped_conv2d_fwd_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances) + PassThrough>>>& instances) { add_device_operation_instances( instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, - NHWGC, - GKYXC, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + NGCDHW, + GKCZYX, Empty_Tuple, - NHWGK, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); - + NGKDHW, + ConvFwdDefault>{}); add_device_operation_instances( instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, - NHWGC, - GKYXC, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + NGCDHW, + GKCZYX, Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); - + NGKDHW, + ConvFwd1x1P0>{}); add_device_operation_instances( instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, - NHWGC, - GKYXC, + device_grouped_conv_fwd_wmma_cshufflev3_int8_generic_instances<3, + NGCDHW, + GKCZYX, Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); - - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_wmma_cshufflev3_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>, - Shards, - ShardIndex>{}); + NGKDHW, + ConvFwd1x1S1P0>{}); } -} // namespace ck::tensor_operation::device::instance +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck