Add predicates for all device op instances.

This commit is contained in:
Ville Pietilä
2025-10-29 14:39:03 +00:00
parent e5eb4edd1a
commit 74ba32ea58
4 changed files with 174 additions and 22 deletions

View File

@@ -454,15 +454,16 @@ template <ConvSignatureDescriptor auto SIGNATURE,
auto VERSION>
struct ConvFactory;
// Factory specialization for an instance of a grouped forward convolution kernel.
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
/*static constexpr auto*/
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;

View File

@@ -17,11 +17,11 @@
// signature at compile time.
#pragma once
#include <variant>
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/conv_signature_predicates.hpp"
namespace ck_tile::builder {
@@ -63,18 +63,7 @@ template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
//requires ConvDeviceOp<Sig.device_operation>;
};
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder

View File

@@ -0,0 +1,162 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* Conv Direction Predicates
**********************************************/
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
/**********************************************
* Conv Fwd Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
(Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
// Generic predicate to check if signature uses any forward convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsForward =
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
/**********************************************
* Conv Bwd Weight Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdWeight operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
(Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
// Generic predicate to check if signature uses any backward weight convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardWeight =
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
/**********************************************
* Conv Bwd Data Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
(Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
// Generic predicate to check if signature uses any backward data convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardData =
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
/**********************************************
* Generic Device Op Predicates
**********************************************/
// Generic predicate to check if signature uses any device operation.
template <auto Sig>
concept IsValidConvDeviceOp =
ConvDeviceOpIsForward<Sig> ||
ConvDeviceOpIsBackwardData<Sig> ||
ConvDeviceOpIsBackwardWeight<Sig>;
} // namespace ck_tile::builder

View File

@@ -81,22 +81,22 @@ enum class FwdGroupConvDeviceOperation
// Backward data convolution device operations.
enum class BwdDataGroupConvDeviceOperation
{
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1,
DeviceGroupedConvBwdDataMultipleD,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
};
// Backward weight convolution device operations.
enum class BwdWeightGroupConvDeviceOperation
{
DeviceGroupedConvBwdWeight,
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Dl,
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeightMultipleD,
DeviceGroupedConvBwdWeight_Dl
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
};
// Structural type for device operation