mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)
* Separate layouts into separate entities for input, weight, and output tensors.
* Add test for handling bias tensor layouts.
* Use instance string in builder tests.
* Add handling of output bias data types and layouts.
* Generalize handling of the elementwise ops.
* Test fix.
* Create builder for layouts.
* Layout builder improvements.
* Improve layout builder.
* Simplify bias layout handling.
* Code clean-up.
* Move layout utils into separate file.
* Remove hard-coded layout combinations.
* Small code clean-up.
* Move data type utils into a separate file.
* Add data types, layouts, and elementwise ops per conv tensor.
* Builder bug fixes after refactoring.
* Working baseline.
* Make signature definition look nice in the test code.
* Move TensorConfig into test implementations.
* Fix all fwd conv builder tests.
* Fix conv traits and descriptors tests.
* More factory assets under a separate directory.
* Fix building conv traits.
* Fix clang-format.
* Add Readme doc to describe the design.
* Add link to main Readme. Fix links in the builder design doc.
* Clean-up data type/layout/elementwise op conversions.
* Switch from dimension and tensor type specific layouts to a flat list of tensor layouts.
* Fix clang-formatting.
* Fix clang-format for test code.
* Simplify fwd conv signature definitions in the test code.
* Remove accidental edits.
* Fix comment string.
* Fix instance factory after rebase.
* Fix tests after rebase.
* Unify layout handling.
* Add more conv layout unit tests.
* Clang-format.
* Fix merge conflicts.
* Improve elementwise op handling.
---------
Co-authored-by: Ville Pietilä <>
[ROCm/composable_kernel commit: 9cb1f421bc]
This commit is contained in:
@@ -85,7 +85,10 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
@@ -212,7 +215,10 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
@@ -295,7 +301,10 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
Reference in New Issue
Block a user