From b9ee41c660fd4172803d6187e54408d701c6a5e4 Mon Sep 17 00:00:00 2001 From: kabrahamAMD Date: Wed, 19 Nov 2025 21:31:05 +0100 Subject: [PATCH] [CK_Builder ]fixed accidental drop of get_elementwise_operation during merge and added usage of get_elementwise_operation() to other builder instances (#3238) Fixed issues encountered during merge of #3192 * fixed accidental drop of get_elementwise_operation during merge and added call to get_elementwise_op to 4 other builders * run clang-format --------- Co-authored-by: Kevin Abraham [ROCm/composable_kernel commit: 964f8e1f60395a4fd8ecbfe8907bff1b8d881314] --- .../include/ck_tile/builder/conv_factory.hpp | 10 ++-- .../ck_tile/builder/conv_signature_utils.hpp | 47 +++++++++++++++++++ .../builder/test/test_conv_description.cpp | 1 + 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index d839518285..39260c8acd 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -58,6 +58,8 @@ #include "ck_tile/builder/types.hpp" #include "ck_tile/builder/versions.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" + namespace ck_tile::builder::factory_internal { // Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types. @@ -665,7 +667,7 @@ struct ConvFactory SPATIAL_DIM, ConvDirection::FORWARD>()); using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps; + using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = @@ -762,7 +764,7 @@ struct ConvFactory SPATIAL_DIM, ConvDirection::FORWARD>()); using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps; + using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = @@ -858,7 +860,7 @@ struct ConvFactory SPATIAL_DIM, ConvDirection::FORWARD>()); using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps; + using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = @@ -980,7 +982,7 @@ struct ConvFactory SPATIAL_DIM, ConvDirection::FORWARD>()); using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps; + using Ops = factory_internal::ElementwiseOps()>; using AlgorithmType = decltype(ALGORITHM); static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp new file mode 100644 index 0000000000..3ba2bf24dd --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp @@ -0,0 +1,47 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { +/********************************************** + * constexpr helper functions for optional parameters + **********************************************/ + +template +concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; }; + +template +concept ProvidesConvolutionDirection = requires { Sig.direction; }; + +template +constexpr auto get_elementwise_operation() +{ + if constexpr(ProvidesElementwiseOperation) + { + return Sig.elementwise_operation; + } + else + { + return ElementwiseOperation::PASS_THROUGH; + } +} + +template +constexpr auto get_conv_direction() +{ + if constexpr(ProvidesConvolutionDirection) + { + return Sig.direction; + } + else + { + return ConvDirection::FORWARD; + } +} +} // namespace ck_tile::builder diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index b53cdc39c7..c2f7039348 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -49,6 +49,7 @@ struct ConvSignatureWithInvalidOptionalParams ckb::GroupConvDeviceOp device_operation = ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; }; + static_assert(!ckb::ConvSignatureDescriptor); struct DefaultAlgorithm