[ck_builder] add utility functions to convolution (#3459)

* reinstate conv_signature_utils.hpp

* added tests for elementwise operation getters

* add tests for getDataType functions

* added test for no data type specified

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
kabrahamAMD
2025-12-23 10:39:49 +01:00
committed by GitHub
parent ead81d1b0b
commit 4ce7d4c511
3 changed files with 300 additions and 4 deletions

View File

@@ -80,12 +80,24 @@ concept ConvOutputLayout3D =
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
template <typename T>
concept HasDataType = requires(T t) {
{ t.data_type };
};
// Note: for signature and TensorConfigDescriptor,
// it is not required to provide a default data type, but if one is provided, check if well defined
template <typename T>
concept DataTypeWellDefinedIfProvided = requires(T t) {
requires !HasDataType<T> || requires {
{ t.data_type } -> std::convertible_to<DataType>;
};
};
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<TensorLayout>;
// Only require that data type is defined. It might be set to undefined value, in which case the
// signature's data type is used.
{ t.data_type } -> std::convertible_to<DataType>;
requires DataTypeWellDefinedIfProvided<T>;
};
template <typename T>
@@ -164,11 +176,11 @@ concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
requires DataTypeWellDefinedIfProvided<T>;
};
// Concept to validate a convolution signature's values.

View File

@@ -0,0 +1,192 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* constexpr helper functions for optional parameters
**********************************************/
template <auto Sig>
concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; };
template <auto Sig>
concept ProvidesDataType = requires { Sig.data_type; };
template <auto ConvTensor>
concept ConvTensorHasOp = requires { ConvTensor.operation; };
template <auto Sig>
concept ProvidesConvolutionDirection = requires { Sig.direction; };
// returns elementwise operation for input tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getInputElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.input>)
{
return Sig.input.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns elementwise operation for weight tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getWeightElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.weight>)
{
return Sig.weight.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns elementwise operation for output tensor
// will defalut to signature's generic type if provided
// otherwise, default to PASS_THROUGH
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getOutputElementwiseOperation()
{
if constexpr(ConvTensorHasOp<Sig.output>)
{
return Sig.output.operation.elementwise_operation;
}
else if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
// returns convolution direction for signature. Will default to FORWARD if not provided by signature
template <auto Sig>
requires ValidConvSignature<Sig>
constexpr auto getConvDirection()
{
if constexpr(ProvidesConvolutionDirection<Sig>)
{
return Sig.direction;
}
else
{
return ConvDirection::FORWARD;
}
}
// generic helper that returns data_type if provided and UNDEFINED otherwise
// can be used on both signature and TensorConfigDescriptor objects
template <auto TensorConfigOrSig>
constexpr auto getDataType()
{
if constexpr(ProvidesDataType<TensorConfigOrSig>)
{
return TensorConfigOrSig.data_type;
}
else
{
return DataType::UNDEFINED_DATA_TYPE;
}
}
// return data type of input tensor
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getInputDataType()
{
constexpr auto tensorDataType = getDataType<Sig.input.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getWeightDataType()
{
constexpr auto tensorDataType = getDataType<Sig.weight.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getOutputDataType()
{
constexpr auto tensorDataType = getDataType<Sig.output.config>();
constexpr auto universalDataType = getDataType<Sig>();
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
{
return tensorDataType;
}
else
{
return universalDataType;
}
}
// returns data type if and only if all tensors have the same type.
// Otherwise, return DataType::UNDEFINED_DATA_TYPE
template <auto Sig>
requires ValidConvSignature<Sig>
consteval auto getDataTypeIfCommon()
{
auto inputDataType = getInputDataType<Sig>();
auto weightDataType = getWeightDataType<Sig>();
auto outputDataType = getOutputDataType<Sig>();
if(inputDataType == weightDataType && inputDataType == outputDataType)
{
return inputDataType;
}
else
{
return DataType::UNDEFINED_DATA_TYPE;
}
}
} // namespace ck_tile::builder

View File

@@ -10,6 +10,7 @@
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
namespace {
@@ -35,6 +36,18 @@ struct TensorConfig
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
};
struct TensorConfigNoDataType
{
ckb::TensorLayout layout;
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
};
struct ConvTensorNoDataType
{
TensorConfigNoDataType config;
TensorOp operation{};
};
struct ConvTensorSimple
{
TensorConfig config;
@@ -155,6 +168,85 @@ struct DefaultAlgorithm
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
struct ConvSignatureUtilsTest1
{
using enum ckb::DataType;
using enum ckb::TensorLayout;
using enum ckb::ConvDirection;
using enum ckb::ElementwiseOperation;
int spatial_dim = 2;
ckb::DataType data_type = FP16;
ckb::DataType accumulation_data_type = FP32;
ckb::ConvDirection direction = FORWARD;
ConvTensorWithOp input = {
.config = {GNHWC, FP16},
};
ConvTensorWithOp weight = {.config = {GKYXC, FP16}};
ConvTensorWithOp output = {.config = {GNHWK, UNDEFINED_DATA_TYPE}, .operation = {SCALE}};
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest1>);
struct ConvSignatureUtilsTest2
{
using enum ckb::DataType;
using enum ckb::TensorLayout;
using enum ckb::ConvDirection;
using enum ckb::ElementwiseOperation;
int spatial_dim = 2;
ckb::DataType data_type = FP16;
ckb::ElementwiseOperation elementwise_operation = CONV_INVSCALE;
ckb::DataType accumulation_data_type = FP32;
ckb::ConvDirection direction = FORWARD;
ConvTensorSimple input = {
.config = {GNHWC, FP16},
};
ConvTensorNoDataType weight = {.config = {GKYXC}, .operation = {POWER}};
ConvTensorWithOp output = {.config = {GNHWK, BF16}, .operation = {GELU}};
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest2>);
TEST(ConvUtilsTest, getDataType1)
{
using enum ckb::DataType;
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), FP16);
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), FP16);
}
TEST(ConvUtilsTest, getDataType2)
{
using enum ckb::DataType;
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), BF16);
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), UNDEFINED_DATA_TYPE);
}
TEST(ConvUtilsTest, getElementwiseOperation1)
{
using enum ckb::ElementwiseOperation;
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), SCALE);
}
TEST(ConvUtilsTest, getElementwiseOperation2)
{
using enum ckb::ElementwiseOperation;
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), CONV_INVSCALE);
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), POWER);
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), GELU);
}
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
static constexpr const ConvSignature SIGNATURE;