mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[ck_builder] add utility functions to convolution (#3459)
* reinstate conv_signature_utils.hpp * added tests for elementwise operation getters * add tests for getDataType functions * added test for no data type specified --------- Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -35,6 +36,18 @@ struct TensorConfig
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
struct TensorConfigNoDataType
|
||||
{
|
||||
ckb::TensorLayout layout;
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
struct ConvTensorNoDataType
|
||||
{
|
||||
TensorConfigNoDataType config;
|
||||
TensorOp operation{};
|
||||
};
|
||||
|
||||
struct ConvTensorSimple
|
||||
{
|
||||
TensorConfig config;
|
||||
@@ -155,6 +168,85 @@ struct DefaultAlgorithm
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
struct ConvSignatureUtilsTest1
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
using enum ckb::ConvDirection;
|
||||
using enum ckb::ElementwiseOperation;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ckb::ConvDirection direction = FORWARD;
|
||||
ConvTensorWithOp input = {
|
||||
.config = {GNHWC, FP16},
|
||||
};
|
||||
ConvTensorWithOp weight = {.config = {GKYXC, FP16}};
|
||||
ConvTensorWithOp output = {.config = {GNHWK, UNDEFINED_DATA_TYPE}, .operation = {SCALE}};
|
||||
};
|
||||
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest1>);
|
||||
|
||||
struct ConvSignatureUtilsTest2
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
using enum ckb::ConvDirection;
|
||||
using enum ckb::ElementwiseOperation;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = CONV_INVSCALE;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ckb::ConvDirection direction = FORWARD;
|
||||
ConvTensorSimple input = {
|
||||
.config = {GNHWC, FP16},
|
||||
};
|
||||
ConvTensorNoDataType weight = {.config = {GKYXC}, .operation = {POWER}};
|
||||
ConvTensorWithOp output = {.config = {GNHWK, BF16}, .operation = {GELU}};
|
||||
};
|
||||
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureUtilsTest2>);
|
||||
|
||||
TEST(ConvUtilsTest, getDataType1)
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), FP16);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getDataType2)
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getWeightDataType<SIGNATURE>(), FP16);
|
||||
EXPECT_THAT(ckb::getOutputDataType<SIGNATURE>(), BF16);
|
||||
EXPECT_THAT(ckb::getDataTypeIfCommon<SIGNATURE>(), UNDEFINED_DATA_TYPE);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getElementwiseOperation1)
|
||||
{
|
||||
using enum ckb::ElementwiseOperation;
|
||||
static constexpr const ConvSignatureUtilsTest1 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
|
||||
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), PASS_THROUGH);
|
||||
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), SCALE);
|
||||
}
|
||||
|
||||
TEST(ConvUtilsTest, getElementwiseOperation2)
|
||||
{
|
||||
using enum ckb::ElementwiseOperation;
|
||||
static constexpr const ConvSignatureUtilsTest2 SIGNATURE;
|
||||
EXPECT_THAT(ckb::getInputElementwiseOperation<SIGNATURE>(), CONV_INVSCALE);
|
||||
EXPECT_THAT(ckb::getWeightElementwiseOperation<SIGNATURE>(), POWER);
|
||||
EXPECT_THAT(ckb::getOutputElementwiseOperation<SIGNATURE>(), GELU);
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
|
||||
Reference in New Issue
Block a user