[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)

* Separate layouts into separate entities for input, weight, and output tensors.

* Add test for handling bias tensor layouts.

* Use instance string in builder tests.

* Add handling of output bias data types and layouts.

* Generalize handling of the elementwise ops.

* Test fix.

* Create builder for layouts.

* Layout builder improvements.

* Improve layout builder.

* Simplify bias layout handling.

* Code clean-up.

* Move layout utils into separate file.

* Remove hard-coded layout combinations.

* Small code clean-up.

* Move data type utils into a separate file.

* Add data types, layouts, and elementwise ops per conv tensor.

* Builder bug fixes after refactoring.

* Working baseline.

* Make signature definition look nice in the test code.

* Move TensorConfig into test implementations.

* Fix all fwd conv builder tests.

* Fix conv traits and descriptors tests.

* More factory assets under a separate directory.

* Fix building conv traits.

* Fix clang-format.

* Add Readme doc to describe the design.

* Add link to main Readme. Fix links in the builder design doc.

* Clean-up data type/layout/elementwise op conversions.

* Switch from dimension and tensor type specific layouts to a flat list of tensor layouts.

* Fix clang-formatting.

* Fix clang-format for test code.

* Simplify fwd conv signature definitions in the test code.

* Remove accidental edits.

* Fix comment string.

* Fix instance factory after rebase.

* Fix tests after rebase.

* Unify layout handling.

* Add more conv layout unit tests.

* Clang-format.

* Fix merge conflicts.

* Improve elementwise op handling.

---------

Co-authored-by: Ville Pietilä <>

[ROCm/composable_kernel commit: 9cb1f421bc]
This commit is contained in:
Ville Pietilä
2025-12-04 12:58:31 +02:00
committed by GitHub
parent a8f5d21fb8
commit 419aa4e420
37 changed files with 1731 additions and 617 deletions

View File

@@ -85,7 +85,10 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -212,7 +215,10 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -295,7 +301,10 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);