diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 252d423716..573577f2ee 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -454,15 +454,16 @@ template struct ConvFactory; -// Factory specialization for an instance of a grouped forward convolution kernel. +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// of a grouped forward convolution kernel. template - requires ConvDirectionIsForward + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - /*static constexpr auto*/ using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 76e5590ad6..7864cde1ae 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -17,11 +17,11 @@ // signature at compile time. #pragma once -#include #include #include #include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -63,18 +63,7 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; + //requires ConvDeviceOp; }; -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp new file mode 100644 index 0000000000..9b47d87329 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + + +#include +#include + +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { + +/********************************************** + * Conv Direction Predicates + **********************************************/ + +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + +/********************************************** + * Conv Fwd Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); + +// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); + +// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); + +// Generic predicate to check if signature uses any forward convolution device operation. +template +concept ConvDeviceOpIsForward = + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + + +/********************************************** + * Conv Bwd Weight Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdWeight operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); + +// Predicate for DeviceGroupedConvBwdWeight_Dl operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); + +// Generic predicate to check if signature uses any backward weight convolution device operation. +template +concept ConvDeviceOpIsBackwardWeight = + ConvDeviceOpIs_DeviceGroupedConvBwdWeight || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; + +/********************************************** + * Conv Bwd Data Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); + +// Predicate for DeviceGroupedConvBwdDataMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); + +// Generic predicate to check if signature uses any backward data convolution device operation. +template +concept ConvDeviceOpIsBackwardData = + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + +/********************************************** + * Generic Device Op Predicates + **********************************************/ + +// Generic predicate to check if signature uses any device operation. +template +concept IsValidConvDeviceOp = + ConvDeviceOpIsForward || + ConvDeviceOpIsBackwardData || + ConvDeviceOpIsBackwardWeight; + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 509f240edd..7c0e23abde 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -81,22 +81,22 @@ enum class FwdGroupConvDeviceOperation // Backward data convolution device operations. enum class BwdDataGroupConvDeviceOperation { - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, DeviceGroupedConvBwdDataMultipleD, - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 }; // Backward weight convolution device operations. enum class BwdWeightGroupConvDeviceOperation { DeviceGroupedConvBwdWeight, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Dl, DeviceGroupedConvBwdWeight_Xdl_CShuffle, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, DeviceGroupedConvBwdWeightMultipleD, - DeviceGroupedConvBwdWeight_Dl + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, }; // Structural type for device operation