From c618d8bba37f22545828c410e3a48aac8d20dbca Mon Sep 17 00:00:00 2001 From: kabrahamAMD Date: Tue, 23 Dec 2025 10:39:49 +0100 Subject: [PATCH] [ck_builder] add utility functions to convolution (#3459) * reinstate conv_signature_utils.hpp * added tests for elementwise operation getters * add tests for getDataType functions * added test for no data type specified --------- Co-authored-by: Kevin Abraham [ROCm/composable_kernel commit: 4ce7d4c511c7e98a9ac01580ed1e9112e59061a0] --- .../builder/conv_signature_concepts.hpp | 20 +- .../ck_tile/builder/conv_signature_utils.hpp | 192 ++++++++++++++++++ .../builder/test/test_conv_description.cpp | 92 +++++++++ 3 files changed, 300 insertions(+), 4 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 8dc92c6bef..39e081ec8d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -80,12 +80,24 @@ concept ConvOutputLayout3D = (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) || (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided); +template +concept HasDataType = requires(T t) { + { t.data_type }; +}; + +// Note: for signature and TensorConfigDescriptor, +// it is not required to provide a default data type, but if one is provided, check if well defined +template +concept DataTypeWellDefinedIfProvided = requires(T t) { + requires !HasDataType || requires { + { t.data_type } -> std::convertible_to; + }; +}; + template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - // Only require that data type is defined. It might be set to undefined value, in which case the - // signature's data type is used. - { t.data_type } -> std::convertible_to; + requires DataTypeWellDefinedIfProvided; }; template @@ -164,11 +176,11 @@ concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) { template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; - { t.data_type } -> std::convertible_to; { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; requires ConvolutionDirectionWellDefinedIfProvided; + requires DataTypeWellDefinedIfProvided; }; // Concept to validate a convolution signature's values. diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp new file mode 100644 index 0000000000..7ff9f7f654 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_utils.hpp @@ -0,0 +1,192 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { +/********************************************** + * constexpr helper functions for optional parameters + **********************************************/ + +template +concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; }; + +template +concept ProvidesDataType = requires { Sig.data_type; }; + +template +concept ConvTensorHasOp = requires { ConvTensor.operation; }; + +template +concept ProvidesConvolutionDirection = requires { Sig.direction; }; + +// returns elementwise operation for input tensor +// will defalut to signature's generic type if provided +// otherwise, default to PASS_THROUGH +template + requires ValidConvSignature +constexpr auto getInputElementwiseOperation() +{ + if constexpr(ConvTensorHasOp) + { + return Sig.input.operation.elementwise_operation; + } + else if constexpr(ProvidesElementwiseOperation) + { + return Sig.elementwise_operation; + } + else + { + return ElementwiseOperation::PASS_THROUGH; + } +} + +// returns elementwise operation for weight tensor +// will defalut to signature's generic type if provided +// otherwise, default to PASS_THROUGH +template + requires ValidConvSignature +constexpr auto getWeightElementwiseOperation() +{ + if constexpr(ConvTensorHasOp) + { + return Sig.weight.operation.elementwise_operation; + } + else if constexpr(ProvidesElementwiseOperation) + { + return Sig.elementwise_operation; + } + else + { + return ElementwiseOperation::PASS_THROUGH; + } +} + +// returns elementwise operation for output tensor +// will defalut to signature's generic type if provided +// otherwise, default to PASS_THROUGH +template + requires ValidConvSignature +constexpr auto getOutputElementwiseOperation() +{ + if constexpr(ConvTensorHasOp) + { + return Sig.output.operation.elementwise_operation; + } + else if constexpr(ProvidesElementwiseOperation) + { + return Sig.elementwise_operation; + } + else + { + return ElementwiseOperation::PASS_THROUGH; + } +} + +// returns convolution direction for signature. Will default to FORWARD if not provided by signature +template + requires ValidConvSignature +constexpr auto getConvDirection() +{ + if constexpr(ProvidesConvolutionDirection) + { + return Sig.direction; + } + else + { + return ConvDirection::FORWARD; + } +} + +// generic helper that returns data_type if provided and UNDEFINED otherwise +// can be used on both signature and TensorConfigDescriptor objects +template +constexpr auto getDataType() +{ + if constexpr(ProvidesDataType) + { + return TensorConfigOrSig.data_type; + } + else + { + return DataType::UNDEFINED_DATA_TYPE; + } +} + +// return data type of input tensor +template + requires ValidConvSignature +consteval auto getInputDataType() +{ + constexpr auto tensorDataType = getDataType(); + constexpr auto universalDataType = getDataType(); + if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE) + { + return tensorDataType; + } + else + { + return universalDataType; + } +} + +template + requires ValidConvSignature +consteval auto getWeightDataType() +{ + constexpr auto tensorDataType = getDataType(); + constexpr auto universalDataType = getDataType(); + if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE) + { + return tensorDataType; + } + else + { + return universalDataType; + } +} + +template + requires ValidConvSignature +consteval auto getOutputDataType() +{ + constexpr auto tensorDataType = getDataType(); + constexpr auto universalDataType = getDataType(); + if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE) + { + return tensorDataType; + } + else + { + return universalDataType; + } +} + +// returns data type if and only if all tensors have the same type. +// Otherwise, return DataType::UNDEFINED_DATA_TYPE +template + requires ValidConvSignature +consteval auto getDataTypeIfCommon() +{ + + auto inputDataType = getInputDataType(); + auto weightDataType = getWeightDataType(); + auto outputDataType = getOutputDataType(); + + if(inputDataType == weightDataType && inputDataType == outputDataType) + { + return inputDataType; + } + else + { + return DataType::UNDEFINED_DATA_TYPE; + } +} +} // namespace ck_tile::builder diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index dca0e858eb..5d6bc102e6 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -10,6 +10,7 @@ #include "testing_utils.hpp" #include "impl/conv_signature_types.hpp" #include "impl/conv_algorithm_types.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" namespace { @@ -35,6 +36,18 @@ struct TensorConfig ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE}; }; +struct TensorConfigNoDataType +{ + ckb::TensorLayout layout; + ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE}; +}; + +struct ConvTensorNoDataType +{ + TensorConfigNoDataType config; + TensorOp operation{}; +}; + struct ConvTensorSimple { TensorConfig config; @@ -155,6 +168,85 @@ struct DefaultAlgorithm }; static_assert(ckb::ConvAlgorithmDescriptor); +struct ConvSignatureUtilsTest1 +{ + using enum ckb::DataType; + using enum ckb::TensorLayout; + using enum ckb::ConvDirection; + using enum ckb::ElementwiseOperation; + + int spatial_dim = 2; + ckb::DataType data_type = FP16; + ckb::DataType accumulation_data_type = FP32; + ckb::ConvDirection direction = FORWARD; + ConvTensorWithOp input = { + .config = {GNHWC, FP16}, + }; + ConvTensorWithOp weight = {.config = {GKYXC, FP16}}; + ConvTensorWithOp output = {.config = {GNHWK, UNDEFINED_DATA_TYPE}, .operation = {SCALE}}; +}; + +static_assert(ckb::ConvSignatureDescriptor); + +struct ConvSignatureUtilsTest2 +{ + using enum ckb::DataType; + using enum ckb::TensorLayout; + using enum ckb::ConvDirection; + using enum ckb::ElementwiseOperation; + + int spatial_dim = 2; + ckb::DataType data_type = FP16; + ckb::ElementwiseOperation elementwise_operation = CONV_INVSCALE; + ckb::DataType accumulation_data_type = FP32; + ckb::ConvDirection direction = FORWARD; + ConvTensorSimple input = { + .config = {GNHWC, FP16}, + }; + ConvTensorNoDataType weight = {.config = {GKYXC}, .operation = {POWER}}; + ConvTensorWithOp output = {.config = {GNHWK, BF16}, .operation = {GELU}}; +}; + +static_assert(ckb::ConvSignatureDescriptor); + +TEST(ConvUtilsTest, getDataType1) +{ + using enum ckb::DataType; + static constexpr const ConvSignatureUtilsTest1 SIGNATURE; + EXPECT_THAT(ckb::getInputDataType(), FP16); + EXPECT_THAT(ckb::getWeightDataType(), FP16); + EXPECT_THAT(ckb::getOutputDataType(), FP16); + EXPECT_THAT(ckb::getDataTypeIfCommon(), FP16); +} + +TEST(ConvUtilsTest, getDataType2) +{ + using enum ckb::DataType; + static constexpr const ConvSignatureUtilsTest2 SIGNATURE; + EXPECT_THAT(ckb::getInputDataType(), FP16); + EXPECT_THAT(ckb::getWeightDataType(), FP16); + EXPECT_THAT(ckb::getOutputDataType(), BF16); + EXPECT_THAT(ckb::getDataTypeIfCommon(), UNDEFINED_DATA_TYPE); +} + +TEST(ConvUtilsTest, getElementwiseOperation1) +{ + using enum ckb::ElementwiseOperation; + static constexpr const ConvSignatureUtilsTest1 SIGNATURE; + EXPECT_THAT(ckb::getInputElementwiseOperation(), PASS_THROUGH); + EXPECT_THAT(ckb::getWeightElementwiseOperation(), PASS_THROUGH); + EXPECT_THAT(ckb::getOutputElementwiseOperation(), SCALE); +} + +TEST(ConvUtilsTest, getElementwiseOperation2) +{ + using enum ckb::ElementwiseOperation; + static constexpr const ConvSignatureUtilsTest2 SIGNATURE; + EXPECT_THAT(ckb::getInputElementwiseOperation(), CONV_INVSCALE); + EXPECT_THAT(ckb::getWeightElementwiseOperation(), POWER); + EXPECT_THAT(ckb::getOutputElementwiseOperation(), GELU); +} + TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) { static constexpr const ConvSignature SIGNATURE;