mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[ck_builder] add utility functions to convolution (#3459)
* reinstate conv_signature_utils.hpp * added tests for elementwise operation getters * add tests for getDataType functions * added test for no data type specified --------- Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
@@ -80,12 +80,24 @@ concept ConvOutputLayout3D =
|
||||
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
|
||||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
|
||||
|
||||
template <typename T>
|
||||
concept HasDataType = requires(T t) {
|
||||
{ t.data_type };
|
||||
};
|
||||
|
||||
// Note: for signature and TensorConfigDescriptor,
|
||||
// it is not required to provide a default data type, but if one is provided, check if well defined
|
||||
template <typename T>
|
||||
concept DataTypeWellDefinedIfProvided = requires(T t) {
|
||||
requires !HasDataType<T> || requires {
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<TensorLayout>;
|
||||
// Only require that data type is defined. It might be set to undefined value, in which case the
|
||||
// signature's data type is used.
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -164,11 +176,11 @@ concept HasElementwiseOpWithAuxiliaryOperands = requires(T t) {
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.input } -> ConvTensorDescriptor;
|
||||
{ t.weight } -> ConvTensorDescriptor;
|
||||
{ t.output } -> ConvTensorDescriptor;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
/**********************************************
|
||||
* constexpr helper functions for optional parameters
|
||||
**********************************************/
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesElementwiseOperation = requires { Sig.elementwise_operation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesDataType = requires { Sig.data_type; };
|
||||
|
||||
template <auto ConvTensor>
|
||||
concept ConvTensorHasOp = requires { ConvTensor.operation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesConvolutionDirection = requires { Sig.direction; };
|
||||
|
||||
// returns elementwise operation for input tensor
|
||||
// will defalut to signature's generic type if provided
|
||||
// otherwise, default to PASS_THROUGH
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
constexpr auto getInputElementwiseOperation()
|
||||
{
|
||||
if constexpr(ConvTensorHasOp<Sig.input>)
|
||||
{
|
||||
return Sig.input.operation.elementwise_operation;
|
||||
}
|
||||
else if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
// returns elementwise operation for weight tensor
|
||||
// will defalut to signature's generic type if provided
|
||||
// otherwise, default to PASS_THROUGH
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
constexpr auto getWeightElementwiseOperation()
|
||||
{
|
||||
if constexpr(ConvTensorHasOp<Sig.weight>)
|
||||
{
|
||||
return Sig.weight.operation.elementwise_operation;
|
||||
}
|
||||
else if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
// returns elementwise operation for output tensor
|
||||
// will defalut to signature's generic type if provided
|
||||
// otherwise, default to PASS_THROUGH
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
constexpr auto getOutputElementwiseOperation()
|
||||
{
|
||||
if constexpr(ConvTensorHasOp<Sig.output>)
|
||||
{
|
||||
return Sig.output.operation.elementwise_operation;
|
||||
}
|
||||
else if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
// returns convolution direction for signature. Will default to FORWARD if not provided by signature
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
constexpr auto getConvDirection()
|
||||
{
|
||||
if constexpr(ProvidesConvolutionDirection<Sig>)
|
||||
{
|
||||
return Sig.direction;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ConvDirection::FORWARD;
|
||||
}
|
||||
}
|
||||
|
||||
// generic helper that returns data_type if provided and UNDEFINED otherwise
|
||||
// can be used on both signature and TensorConfigDescriptor objects
|
||||
template <auto TensorConfigOrSig>
|
||||
constexpr auto getDataType()
|
||||
{
|
||||
if constexpr(ProvidesDataType<TensorConfigOrSig>)
|
||||
{
|
||||
return TensorConfigOrSig.data_type;
|
||||
}
|
||||
else
|
||||
{
|
||||
return DataType::UNDEFINED_DATA_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
// return data type of input tensor
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
consteval auto getInputDataType()
|
||||
{
|
||||
constexpr auto tensorDataType = getDataType<Sig.input.config>();
|
||||
constexpr auto universalDataType = getDataType<Sig>();
|
||||
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return tensorDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return universalDataType;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
consteval auto getWeightDataType()
|
||||
{
|
||||
constexpr auto tensorDataType = getDataType<Sig.weight.config>();
|
||||
constexpr auto universalDataType = getDataType<Sig>();
|
||||
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return tensorDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return universalDataType;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
consteval auto getOutputDataType()
|
||||
{
|
||||
constexpr auto tensorDataType = getDataType<Sig.output.config>();
|
||||
constexpr auto universalDataType = getDataType<Sig>();
|
||||
if constexpr(tensorDataType != DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return tensorDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return universalDataType;
|
||||
}
|
||||
}
|
||||
|
||||
// returns data type if and only if all tensors have the same type.
|
||||
// Otherwise, return DataType::UNDEFINED_DATA_TYPE
|
||||
template <auto Sig>
|
||||
requires ValidConvSignature<Sig>
|
||||
consteval auto getDataTypeIfCommon()
|
||||
{
|
||||
|
||||
auto inputDataType = getInputDataType<Sig>();
|
||||
auto weightDataType = getWeightDataType<Sig>();
|
||||
auto outputDataType = getOutputDataType<Sig>();
|
||||
|
||||
if(inputDataType == weightDataType && inputDataType == outputDataType)
|
||||
{
|
||||
return inputDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return DataType::UNDEFINED_DATA_TYPE;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile::builder
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -35,6 +36,18 @@ struct TensorConfig
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
struct TensorConfigNoDataType
|
||||
{
|
||||
ckb::TensorLayout layout;
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
struct ConvTensorNoDataType
|
||||
{
|
||||
TensorConfigNoDataType config;
|
||||
TensorOp operation{};
|
||||
};
|
||||
|
||||
struct ConvTensorSimple
|
||||
{
|
||||
TensorConfig config;
|
||||
@@ -155,6 +168,85 @@ struct DefaultAlgorithm
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
struct ConvSignatureUtilsTest1
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
using enum ckb::ConvDirection;
|
||||
using enum ckb::ElementwiseOperation;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ckb::ConvDirection direction = FORWARD;
|
||||
ConvTensorWithOp input = {
|
||||
.config = {GNHWC, FP16},
|
||||
};
|
||||
ConvTensorWithOp weight = {.config = {GKYXC, FP16}};
|
||||
ConvTensorWithOp output = {.config = {GNHWK, UNDEFINED_DATA_TYPE}, .operation = {SCALE}};
|
||||
};
|
||||
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest1>);
|
||||
|
||||
struct ConvSignatureUtilsTest2
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
using enum ckb::ConvDirection;
|
||||
using enum ckb::ElementwiseOperation;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = CONV_INVSCALE;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ckb::ConvDirection direction = FORWARD;
|
||||
ConvTensorSimple input = {
|
||||
.config = {GNHWC, FP16},
|
||||
};
|
||||
ConvTensorNoDataType weight = {.config = {GKYXC}, .operation = {POWER}};
|
||||
ConvTensorWithOp output = {.config = {GNHWK, BF16}, .operation = {GELU}};
|
||||
};
|
||||
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest2>);
|
||||
|
||||
TEST(ConvUtilsTest, getDataType1)
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), FP16);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getDataType2)
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), BF16);
|
||||
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), UNDEFINED_DATA_TYPE);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getElementwiseOperation1)
|
||||
{
|
||||
using enum ckb::ElementwiseOperation;
|
||||
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
|
||||
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
|
||||
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), SCALE);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getElementwiseOperation2)
|
||||
{
|
||||
using enum ckb::ElementwiseOperation;
|
||||
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), CONV_INVSCALE);
|
||||
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), POWER);
|
||||
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), GELU);
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
|
||||
Reference in New Issue
Block a user