From deec3a0dc1d235c8a8cfcdc451f352b38fe5890c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Wed, 5 Nov 2025 09:17:46 +0000 Subject: [PATCH] Remove explicit device op flag from convolution signature. --- .../builder/conv_algorithm_concepts.hpp | 47 ++++- .../include/ck_tile/builder/conv_factory.hpp | 79 +------- .../builder/conv_signature_concepts.hpp | 18 +- .../builder/conv_signature_predicates.hpp | 174 ------------------ .../builder/include/ck_tile/builder/types.hpp | 46 ----- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 8 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 64 ------- .../test/impl/conv_signature_types.hpp | 1 - 16 files changed, 77 insertions(+), 392 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 365835684e..5c397b1162 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,7 +95,7 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// No requirements yet for a ConvAlgorithm concept. +// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this concept. 
template concept ConvAlgorithmDescriptor = std::is_class_v; @@ -183,4 +183,49 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +/******************************************** */ +/* Concepts for the different device ops */ +/******************************************** */ + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvAlgorithmDescriptor && + SpecifiesThreadBlock && + SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && + SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && + SpecifiesBlockGemm && + SpecifiesGemmSpecialization; + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvAlgorithmDescriptor && + SpecifiesThreadBlock && + SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && + SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; + +template +concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = + ConvAlgorithmDescriptor && + SpecifiesThreadBlock && + SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && + SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && + SpecifiesLoopScheduler; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 117faec689..c48228fa37 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -517,7 +517,12 @@ namespace ck_tile::builder { template -struct ConvFactory; +struct ConvFactory +{ + // This will trigger if a specialization for the 
given convolution direction is not found. + // We should always catch this in an earlier validation check. + static_assert(false, "Unsupported device operation."); +}; // Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance // of a grouped forward convolution kernel. @@ -525,7 +530,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -536,26 +541,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesBlockGemm, - "The convolution algorithm descriptor must specify block gemm pipeline."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, "A and B block transfers must both be direct load or not."); @@ -647,7 
+632,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -658,31 +643,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static_assert(SpecifiesNumGroupsToMerge, - "The convolution algorithm descriptor must specify number of groups to merge."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -769,7 +729,7 @@ template requires ConvDirectionIsForward && - 
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -780,27 +740,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseWmmaGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 742dfbb89c..983273b439 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -21,7 +21,6 @@ #include #include "ck_tile/builder/types.hpp" 
-#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -41,9 +40,6 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); -template -concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; - template concept ConvLayout = std::same_as, GroupConvLayout>; @@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) { { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; - { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. @@ -63,7 +58,18 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; - requires IsValidConvDeviceOp; }; +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp deleted file mode 100644 index f016a342d3..0000000000 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/builder/types.hpp" - -namespace ck_tile::builder { - -/********************************************** - * Conv Direction Predicates - **********************************************/ - -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - -/********************************************** - * Conv Fwd Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); - -// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); - -// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); - -// Generic predicate to check if signature uses any forward convolution device operation. -template -concept ConvDeviceOpIsForward = - ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; - -/********************************************** - * Conv Bwd Weight Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdWeight operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); - -// Predicate for DeviceGroupedConvBwdWeight_Dl operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); - -// Generic predicate to check if signature uses any backward weight convolution device operation. -template -concept ConvDeviceOpIsBackwardWeight = - ConvDeviceOpIs_DeviceGroupedConvBwdWeight || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; - -/********************************************** - * Conv Bwd Data Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); - -// Predicate for DeviceGroupedConvBwdDataMultipleD operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); - -// Generic predicate to check if signature uses any backward data convolution device operation. -template -concept ConvDeviceOpIsBackwardData = - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; - -/********************************************** - * Generic Device Op Predicates - **********************************************/ - -// Generic predicate to check if signature uses any device operation. -template -concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || - ConvDeviceOpIsBackwardWeight; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2650f0de16..f09d740d20 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -70,52 +70,6 @@ enum class ConvDirection BACKWARD_WEIGHT }; -// Forward convolution device operations. 
-enum class FwdGroupConvDeviceOperation -{ - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -}; - -// Backward data convolution device operations. -enum class BwdDataGroupConvDeviceOperation -{ - DeviceGroupedConvBwdDataMultipleD, - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 -}; - -// Backward weight convolution device operations. -enum class BwdWeightGroupConvDeviceOperation -{ - DeviceGroupedConvBwdWeight, - DeviceGroupedConvBwdWeight_Dl, - DeviceGroupedConvBwdWeight_Xdl_CShuffle, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, - DeviceGroupedConvBwdWeightMultipleD, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, -}; - -// Structural type for device operation -struct GroupConvDeviceOp -{ - union - { - FwdGroupConvDeviceOperation _fwd; - BwdDataGroupConvDeviceOperation _bwd_data; - BwdWeightGroupConvDeviceOperation _bwd_weight; - }; - - constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} - constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} - constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} -}; - // Fused element-wise operations. 
enum class ElementwiseOperation { diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 262b7349db..d6cda1f427 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -18,9 +18,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::SCALE}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index d54c296a42..330db8d457 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ .thread_block = FwdThreadBlock_64x32x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 336c8c3501..1ec5bbb349 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, 
.direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, .data_type = DataType::I8, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{ .thread_block = FwdThreadBlock_64x64x64, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 057efd73b1..31f2976fd0 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, @@ -46,9 +44,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index b6241eca1c..6276424a77 100644 --- 
a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -16,9 +16,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 766c3adc44..a390510199 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -16,9 +16,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_128x128x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 4de9dfddcb..3c59ae24fb 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, - .elementwise_operation = 
ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 541d62c44f..14d2811918 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_128x128x32, diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index d4306ed981..bce092d5f6 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -17,9 +17,7 @@ TEST(FwdConvInstances, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock_256x256x32, diff 
--git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 1a78028862..e719db89ed 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -126,26 +126,6 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 GemmSpecialization gemm_specialization; BlockGemm block_gemm; }; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesBlockGemm); -static_assert(ckb::SpecifiesGemmSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { @@ -158,30 +138,6 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle size_t num_groups_to_merge; LoopScheduler loop_scheduler; }; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - 
ckb::SpecifiesLoopScheduler); -static_assert( - ckb::SpecifiesNumGroupsToMerge); struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle { @@ -193,25 +149,5 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle size_t num_gemm_k_prefetch_stages; LoopScheduler loop_scheduler; }; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert(ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseWmmaGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert(ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert( - ckb::SpecifiesFwdConcSpecialization); -static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesLoopScheduler); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 5e6684c4cd..71f16aefbe 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -17,7 +17,6 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; - GroupConvDeviceOp device_operation; }; static_assert(ConvSignatureDescriptor);