mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Add predicates for all device op instances.
This commit is contained in:
@@ -454,15 +454,16 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
|
||||
// Factory specialization for an instance of a grouped forward convolution kernel.
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
|
||||
// of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
/*static constexpr auto*/
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
|
||||
@@ -17,11 +17,11 @@
|
||||
// signature at compile time.
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_predicates.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -63,18 +63,7 @@ template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
//requires ConvDeviceOp<Sig.device_operation>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**********************************************
|
||||
* Conv Direction Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
/**********************************************
|
||||
* Conv Fwd Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
// Generic predicate to check if signature uses any forward convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsForward =
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
|
||||
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Weight Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
// Generic predicate to check if signature uses any backward weight convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardWeight =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Conv Bwd Data Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Generic predicate to check if signature uses any backward data convolution device operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIsBackwardData =
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
|
||||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
|
||||
|
||||
/**********************************************
|
||||
* Generic Device Op Predicates
|
||||
**********************************************/
|
||||
|
||||
// Generic predicate to check if signature uses any device operation.
|
||||
template <auto Sig>
|
||||
concept IsValidConvDeviceOp =
|
||||
ConvDeviceOpIsForward<Sig> ||
|
||||
ConvDeviceOpIsBackwardData<Sig> ||
|
||||
ConvDeviceOpIsBackwardWeight<Sig>;
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -81,22 +81,22 @@ enum class FwdGroupConvDeviceOperation
|
||||
// Backward data convolution device operations.
|
||||
enum class BwdDataGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1,
|
||||
DeviceGroupedConvBwdDataMultipleD,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
};
|
||||
|
||||
// Backward weight convolution device operations.
|
||||
enum class BwdWeightGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdWeight,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Dl,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeightMultipleD,
|
||||
DeviceGroupedConvBwdWeight_Dl
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
|
||||
};
|
||||
|
||||
// Structural type for device operation
|
||||
|
||||
Reference in New Issue
Block a user