Remove explicit device op flag from convolution signature.

This commit is contained in:
Ville Pietilä
2025-11-05 09:17:46 +00:00
parent f0291b7956
commit deec3a0dc1
16 changed files with 77 additions and 392 deletions

View File

@@ -95,7 +95,7 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// No requirements yet for a ConvAlgorithm concept.
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this concept.
// Currently satisfied by any class type (std::is_class_v<T>); it places no
// requirements on the members of T.
template <typename T>
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
@@ -183,4 +183,49 @@ concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
};
/******************************************** */
/* Concepts for the different device ops */
/******************************************** */
// Algorithm-descriptor requirements used to select the ConvFactory
// specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
//
// T satisfies the concept when it meets the base ConvAlgorithmDescriptor
// concept and every Specifies* concept listed below. The forward convolution
// specialization is part of the contract, so a descriptor that lacks it is
// rejected when the specialization is selected.
template <typename T>
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
    ConvAlgorithmDescriptor<T> &&
    SpecifiesThreadBlock<T> &&
    SpecifiesGridwiseXdlGemm<T> &&
    SpecifiesBlockTransfer<T> &&
    SpecifiesLdsTransfer<T> &&
    SpecifiesThreadClusterAccessOrder<T> &&
    SpecifiesSourceAccessOrder<T> &&
    SpecifiesBlockGemm<T> &&
    SpecifiesFwdConcSpecialization<T> &&
    SpecifiesGemmSpecialization<T>;
// Algorithm-descriptor requirements used to select the ConvFactory
// specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle.
// T must meet the base ConvAlgorithmDescriptor concept and every Specifies*
// concept in the conjunction below. Unlike the _V3 concept, this one requires
// prefetch-stage, groups-to-merge and loop-scheduler information and does not
// require SpecifiesBlockGemm.
template <typename T>
concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmDescriptor<T> &&
SpecifiesThreadBlock<T> &&
SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> &&
SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> &&
SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
// Algorithm-descriptor requirements used to select the WMMA-based forward
// ConvFactory specialization. T must meet the base ConvAlgorithmDescriptor
// concept and every Specifies* concept in the conjunction below; the gridwise
// GEMM requirement is SpecifiesGridwiseWmmaGemm rather than the Xdl variant.
// NOTE(review): the concept name says "MultipleABD", while the predicate it
// replaces on the factory and the test algorithm struct are both named
// ...MultipleD_Wmma_CShuffle — confirm which spelling is intended.
template <typename T>
concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle =
ConvAlgorithmDescriptor<T> &&
SpecifiesThreadBlock<T> &&
SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> &&
SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> &&
SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> &&
SpecifiesLoopScheduler<T>;
} // namespace ck_tile::builder

View File

@@ -517,7 +517,12 @@ namespace ck_tile::builder {
// Primary ConvFactory template. The constrained specializations that follow
// provide the actual factories; this definition only produces a diagnostic.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto VERSION>
struct ConvFactory;
struct ConvFactory
{
// This will trigger if a specialization for the given convolution direction is not found.
// We should always catch this in an earlier validation check.
// NOTE(review): the condition is a literal `false` that does not depend on any
// template parameter. Compilers that do not implement P2593R1 may reject it when
// the template is defined, even if the primary template is never instantiated —
// confirm against the minimum supported toolchain, or make the condition dependent.
static_assert(false, "Unsupported device operation.");
};
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
// of a grouped forward convolution kernel.
@@ -525,7 +530,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<decltype(ALGORITHM)>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
@@ -536,26 +541,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesBlockGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify block gemm pipeline.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
"A and B block transfers must both be direct load or not.");
@@ -647,7 +632,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<SIGNATURE>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<decltype(ALGORITHM)>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
@@ -658,31 +643,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
"The convolution algorithm descriptor must specify number of prefetch stages.");
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
"The convolution algorithm descriptor must specify loop scheduler.");
static_assert(SpecifiesNumGroupsToMerge<AlgorithmType>,
"The convolution algorithm descriptor must specify number of groups to merge.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
@@ -769,7 +729,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<SIGNATURE>
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle<decltype(ALGORITHM)>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
@@ -780,27 +740,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseWmmaGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
"The convolution algorithm descriptor must specify number of prefetch stages.");
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
"The convolution algorithm descriptor must specify loop scheduler.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =

View File

@@ -21,7 +21,6 @@
#include <type_traits>
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/conv_signature_predicates.hpp"
namespace ck_tile::builder {
@@ -41,9 +40,6 @@ template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
template <typename T>
concept ConvDeviceOp = std::same_as<std::remove_cvref_t<T>, GroupConvDeviceOp>;
// Satisfied when T, after removing cv-qualifiers and references, is exactly
// GroupConvLayout.
template <typename T>
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
@@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) {
{ t.layout } -> ConvLayout;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
{ t.device_operation } -> ConvDeviceOp;
};
// Concept to validate a convolution signature's values.
@@ -63,7 +58,18 @@ template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
requires IsValidConvDeviceOp<Sig>;
};
// Predicate for forward convolution: satisfied when the signature value Sig
// (a non-type template argument) has direction == ConvDirection::FORWARD.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution: satisfied when the signature value
// Sig has direction == ConvDirection::BACKWARD_DATA.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution: satisfied when the signature
// value Sig has direction == ConvDirection::BACKWARD_WEIGHT.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder

View File

@@ -1,174 +0,0 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* Conv Direction Predicates
**********************************************/
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
/**********************************************
* Conv Fwd Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
// Generic predicate to check if signature uses any forward convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsForward =
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
/**********************************************
* Conv Bwd Weight Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdWeight operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
// Generic predicate to check if signature uses any backward weight convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardWeight =
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
/**********************************************
* Conv Bwd Data Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
// Generic predicate to check if signature uses any backward data convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardData =
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
/**********************************************
* Generic Device Op Predicates
**********************************************/
// Generic predicate to check if signature uses any device operation.
template <auto Sig>
concept IsValidConvDeviceOp = ConvDeviceOpIsForward<Sig> || ConvDeviceOpIsBackwardData<Sig> ||
ConvDeviceOpIsBackwardWeight<Sig>;
} // namespace ck_tile::builder

View File

@@ -70,52 +70,6 @@ enum class ConvDirection
BACKWARD_WEIGHT
};
// Forward convolution device operations.
enum class FwdGroupConvDeviceOperation
{
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
};
// Backward data convolution device operations.
enum class BwdDataGroupConvDeviceOperation
{
DeviceGroupedConvBwdDataMultipleD,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
};
// Backward weight convolution device operations.
enum class BwdWeightGroupConvDeviceOperation
{
DeviceGroupedConvBwdWeight,
DeviceGroupedConvBwdWeight_Dl,
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeightMultipleD,
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
};
// Structural type for device operation
struct GroupConvDeviceOp
{
union
{
FwdGroupConvDeviceOperation _fwd;
BwdDataGroupConvDeviceOperation _bwd_data;
BwdWeightGroupConvDeviceOperation _bwd_weight;
};
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}
};
// Fused element-wise operations.
enum class ElementwiseOperation
{

View File

@@ -18,9 +18,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::SCALE,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::SCALE};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
.thread_block = FwdThreadBlock_64x32x32,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
.data_type = DataType::I8,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{
.thread_block = FwdThreadBlock_64x64x64,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,
@@ -46,9 +44,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,

View File

@@ -16,9 +16,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,

View File

@@ -16,9 +16,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_128x128x32,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_128x128x32,

View File

@@ -17,9 +17,7 @@ TEST(FwdConvInstances,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock_256x256x32,

View File

@@ -126,26 +126,6 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
GemmSpecialization gemm_specialization;
BlockGemm block_gemm;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesGemmSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
@@ -158,30 +138,6 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
size_t num_groups_to_merge;
LoopScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
{
@@ -193,25 +149,5 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
size_t num_gemm_k_prefetch_stages;
LoopScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
} // namespace ck_tile::builder::test

View File

@@ -17,7 +17,6 @@ struct ConvSignature
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
GroupConvDeviceOp device_operation;
};
static_assert(ConvSignatureDescriptor<ConvSignature>);