mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)
* Separate layouts into separate entities for input, weight, and output tensors. * Add test for handling bias tensor layouts. * Use instance string in builder tests. * Add handling of output bias data types and layouts. * Generalize handling of the elementwise ops. * Test fix. * Create builder for layouts. * Layout builder improvements. * Improve layout builder. * Simplify bias layout handling. * Code clean-up. * Move layout utils into separate file. * Remove hard-coded layout combinations. * Small code clean-up. * Move data type utils into a separate file. * Add data types, layouts, and elementwise ops per conv tensor. * Builder bug fixes after refactoring. * Working baseline. * Make signature definition look nice in the test code. * Move TensorConfig into test implementations. * Fix all fwd conv builder tests. * Fix conv traits and descriptors tests. * More factory assets under a separate directory. * Fix building conv traits. * Fix clang-format. * Add Readme doc to describe the design. * Add link to main Readme. Fix links in the builder design doc. * Clean-up data type/layout/elementwise op conversions. * Switch from dimension and tensor type specific layouts to a flat list of tensor layouts. * Fix clang-formatting. * Fix clang-format for test code. * Simplify fwd conv signature definitions in the test code. * Remove accidental edits. * Fix comment string. * Fix instance factory after rebase. * Fix tests after rebase. * Unify layout handling. * Add more conv layout unit tests. * Clang-format. * Fix merge conflicts. * Improve elementwise op handling. --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -16,40 +16,79 @@ namespace ckb = ck_tile::builder;
|
||||
namespace ckr = ck_tile::reflect;
|
||||
namespace ckt = ck_tile::test;
|
||||
|
||||
struct TensorOp
|
||||
{
|
||||
ckb::ElementwiseOperation elementwise_operation{ckb::ElementwiseOperation::PASS_THROUGH};
|
||||
};
|
||||
|
||||
struct InvalidTensorOp
|
||||
{
|
||||
int elementwise_operation = 7; // invalid value
|
||||
};
|
||||
static_assert(!ckb::TensorOperatorDescriptor<InvalidTensorOp>);
|
||||
|
||||
struct TensorConfig
|
||||
{
|
||||
ckb::TensorLayout layout;
|
||||
ckb::DataType data_type{ckb::DataType::UNDEFINDED};
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINDED};
|
||||
};
|
||||
|
||||
struct ConvTensorSimple
|
||||
{
|
||||
TensorConfig config;
|
||||
};
|
||||
|
||||
struct ConvTensorWithOp
|
||||
{
|
||||
TensorConfig config;
|
||||
TensorOp operation{};
|
||||
};
|
||||
|
||||
struct ConvTensorWithInvalidOp
|
||||
{
|
||||
TensorConfig config;
|
||||
InvalidTensorOp operation{};
|
||||
};
|
||||
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
// ckb::GroupConvDeviceOp device_operation =
|
||||
// ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}};
|
||||
ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}};
|
||||
ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}};
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
// Compile time tests for concepts
|
||||
struct ConvSignatureWithOptionalParams
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ConvTensorWithOp input = {
|
||||
.config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16},
|
||||
};
|
||||
ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}};
|
||||
ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16},
|
||||
.operation = {ckb::ElementwiseOperation::SCALE}};
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
|
||||
|
||||
struct ConvSignatureWithInvalidOptionalParams
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
int elementwise_operation = 7; // this should fail
|
||||
// ckb::GroupConvDeviceOp device_operation =
|
||||
// ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}};
|
||||
ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}};
|
||||
ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}};
|
||||
};
|
||||
|
||||
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
@@ -123,7 +162,9 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
|
||||
"│ ├─ Input Layout: GNHWC\n"
|
||||
"│ ├─ Weight Layout: GKYXC\n"
|
||||
"│ ├─ Output Layout: GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
|
||||
Reference in New Issue
Block a user