From 51b6f6fe7d1ff0eccbea213e0ed78c9cc52cd14a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Fri, 31 Oct 2025 01:13:58 +0200 Subject: [PATCH] [CK_BUILDER] Generalize convolution factory to build arbitrary device operations. (#3116) Generalize the current convolution factory in CK Builder to be able to build instances of any relevant convolution device operation. The main changes are: * Added new enums FwdGroupConvDeviceOperation, BwdDataGroupConvDeviceOperation, and * BwdWeightGroupConvDeviceOperation that contain the device operations for which the builder should be able to build instances. * Create a union structure GroupConvDeviceOp that can represent a single value of the fwd, bwd weight, or bwd data device operations. This would be more naturally represented by std::variant object, but we cannot use std::variant in NTTPs because it is not a structural object. * Introduced a new member device_operation in the ConvSignatureDescriptor concept that assumes GroupConvDeviceOp value. * Added predicates to be used in creation ConvFactory specialization for the different device operation. When we add support for a new device operation, we'll just create a new ConvFactory specialization with appropriate predicates. * Changed handling of the convolution layouts (GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D) to use the union based handling, i.e., there's now a GroupConvLayout union struct that can hold a single value of the 1D, 2D, or 3D layouts. This simplifies the handling of the different layouts as we get rid of templatized convolution signature. These code changes allow developers to work more easily in parallel when adding new device operations. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. [ROCm/composable_kernel commit: b387249fd905b595f2d38ac2a18d8c2aa9b88c00] --- .../include/ck_tile/builder/conv_factory.hpp | 33 +++- .../builder/conv_signature_concepts.hpp | 25 +-- .../builder/conv_signature_predicates.hpp | 174 ++++++++++++++++++ ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 10 +- .../builder/include/ck_tile/builder/types.hpp | 60 ++++++ .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 12 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 6 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 6 +- .../test/impl/conv_signature_types.hpp | 9 +- .../test/utils/ckb_conv_test_common.hpp | 2 +- 14 files changed, 318 insertions(+), 43 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index de8ba4f648..31be8c322c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -158,6 +158,28 @@ struct ConvTensorLayouts +consteval auto GetTensorLayout() +{ + + if constexpr(SPATIAL_DIM == 1) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 2) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 3) + { + return factory_internal::ConvTensorLayouts{}; + } + else + { + static_assert(false, "Unsupported spatial dimension for convolution layout."); + } +} + // Type mappings from builder convolution data type to CK tensor types. template struct ConvTensorTypes @@ -432,16 +454,19 @@ template struct ConvFactory; -// Factory specialization for an instance of a grouped forward convolution kernel. +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// of a grouped forward convolution kernel. template - requires ConvDirectionIsForward + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = - factory_internal::ConvTensorLayouts; + using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 0851f0061e..370e7b6521 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -21,6 +21,7 @@ #include #include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -40,16 +41,21 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +template +concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; + +template +concept ConvLayout = std::same_as, GroupConvLayout>; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; { t.direction } -> std::convertible_to; - requires std::convertible_to || - std::convertible_to || - std::convertible_to; + { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; + { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. @@ -57,18 +63,7 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; + requires IsValidConvDeviceOp; }; -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp new file mode 100644 index 0000000000..f947c7e329 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { + +/********************************************** + * Conv Direction Predicates + **********************************************/ + +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + +/********************************************** + * Conv Fwd Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); + +// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); + +// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); + +// Generic predicate to check if signature uses any forward convolution device operation. +template +concept ConvDeviceOpIsForward = + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + +/********************************************** + * Conv Bwd Weight Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdWeight operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); + +// Predicate for DeviceGroupedConvBwdWeight_Dl operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); + +// Generic predicate to check if signature uses any backward weight convolution device operation. +template +concept ConvDeviceOpIsBackwardWeight = + ConvDeviceOpIs_DeviceGroupedConvBwdWeight || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; + +/********************************************** + * Conv Bwd Data Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); + +// Predicate for DeviceGroupedConvBwdDataMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); + +// Generic predicate to check if signature uses any backward data convolution device operation. +template +concept ConvDeviceOpIsBackwardData = + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + +/********************************************** + * Generic Device Op Predicates + **********************************************/ + +// Generic predicate to check if signature uses any device operation. +template +concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || + ConvDeviceOpIsBackwardWeight; + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 21201b8d50..9ab827e3a5 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -69,7 +69,8 @@ template + typename BComputeDataType, + bool DirectLoad> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; } // namespace ck::tensor_operation::device @@ -124,7 +125,8 @@ template + typename BComputeDataType_, + bool DirectLoad> struct InstanceTraits> + BComputeDataType_, + DirectLoad>> { // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; @@ -336,6 +339,7 @@ struct InstanceTraits(); // 47. AComputeDataType oss << "," << detail::type_name(); // 48. BComputeDataType + oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad oss << ">"; return oss.str(); diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7f49e77f81..47bd8327d4 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -48,6 +48,20 @@ enum class GroupConvLayout3D NGCDHW_GKCZYX_NGKDHW, }; +struct GroupConvLayout +{ + union + { + GroupConvLayout1D _1d; + GroupConvLayout2D _2d; + GroupConvLayout3D _3d; + }; + + constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} + constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} + constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} +}; + // Direction of the convolution operation. enum class ConvDirection { @@ -56,6 +70,52 @@ enum class ConvDirection BACKWARD_WEIGHT }; +// Forward convolution device operations. +enum class FwdGroupConvDeviceOperation +{ + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +}; + +// Backward data convolution device operations. +enum class BwdDataGroupConvDeviceOperation +{ + DeviceGroupedConvBwdDataMultipleD, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +}; + +// Backward weight convolution device operations. +enum class BwdWeightGroupConvDeviceOperation +{ + DeviceGroupedConvBwdWeight, + DeviceGroupedConvBwdWeight_Dl, + DeviceGroupedConvBwdWeight_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, + DeviceGroupedConvBwdWeightMultipleD, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, +}; + +// Structural type for device operation +struct GroupConvDeviceOp +{ + union + { + FwdGroupConvDeviceOperation _fwd; + BwdDataGroupConvDeviceOperation _bwd_data; + BwdWeightGroupConvDeviceOperation _bwd_weight; + }; + + constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} + constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} + constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} +}; + // Fused element-wise operations. enum class ElementwiseOperation { diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index d5b8802896..77ff0fe28f 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -9,12 +9,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 1, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE}; + .elementwise_operation = ElementwiseOperation::SCALE, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 77c5c80489..5be7d5e604 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -28,12 +30,14 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index c81d7543bb..4abe3df40d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index d55a120bb8..5ea804cf8b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index f7bcf49e54..c729148346 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 27b5ddc821..832acd7412 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index c0b6f04383..9d0e107dbc 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 297f827395..cc5490c711 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -3,11 +3,13 @@ #pragma once +#include #include "ck_tile/builder/conv_signature_concepts.hpp" namespace ck_tile::builder::test { -template +using namespace ck_tile::builder; + struct ConvSignature { int spatial_dim; @@ -15,9 +17,8 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; + GroupConvDeviceOp device_operation; }; -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); +static_assert(ConvSignatureDescriptor); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7ad01bd922..cd3943d26f 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -11,7 +11,7 @@ using namespace ck_tile::builder; using namespace test; // Common test implementation -template