[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor (#3331)

* Separate layouts into separate entities for input, weight, and output tensors.

* Add test for handling bias tensor layouts.

* Use instance string in builder tests.

* Add handling of output bias data types and layouts.

* Generalize handling of the elementwise ops.

* Test fix.

* Create builder for layouts.

* Layout builder improvements.

* Improve layout builder.

* Simplify bias layout handling.

* Code clean-up.

* Move layout utils into separate file.

* Remove hard-coded layout combinations.

* Small code clean-up.

* Move data type utils into a separate file.

* Add data types, layouts, and elementwise ops per conv tensor.

* Builder bug fixes after refactoring.

* Working baseline.

* Make signature definition look nice in the test code.

* Move TensorConfig into test implementations.

* Fix all fwd conv builder tests.

* Fix conv traits and descriptors tests.

* Move factory assets under a separate directory.

* Fix building conv traits.

* Fix clang-format.

* Add Readme doc to describe the design.

* Add link to main Readme. Fix links in the builder design doc.

* Clean-up data type/layout/elementwise op conversions.

* Switch from dimension and tensor type specific layouts to a flat list of tensor layouts.

* Fix clang-formatting.

* Fix clang-format for test code.

* Simplify fwd conv signature definitions in the test code.

* Remove accidental edits.

* Fix comment string.

* Fix instance factory after rebase.

* Fix tests after rebase.

* Unify layout handling.

* Add more conv layout unit tests.

* Clang-format.

* Fix merge conflicts.

* Improve elementwise op handling.

---------

Co-authored-by: Ville Pietilä <>
This commit is contained in:
Ville Pietilä
2025-12-04 12:58:31 +02:00
committed by GitHub
parent 583fafc803
commit 9cb1f421bc
37 changed files with 1731 additions and 617 deletions

View File

@@ -119,6 +119,7 @@ add_ck_builder_test(test_ckb_instance_string
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp

View File

@@ -13,11 +13,15 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::SCALE};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NGCW}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::NGKW},
.operation = {.elementwise_operation = ElementwiseOperation::SCALE}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +34,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"NGCW,GKXC,EmptyTuple,NGKW",
"PassThrough,PassThrough,Scale",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v2"});
"MNKPadding",
"Intrawave",
"v2"});
}
} // namespace

View File

@@ -10,14 +10,15 @@ using namespace ck_tile::builder::test_utils;
// 1D FP16 (channels-last) with DEFAULT specialization
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale)
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NWGC}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::NWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
@@ -28,8 +29,12 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"NWGC,GKXC,EmptyTuple,NWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"64,64,32,32",
"Default"});
}
} // namespace

View File

@@ -14,12 +14,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
.data_type = DataType::I8,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.data_type = DataType::I8,
.accumulation_data_type = DataType::INT32,
.input = {.config = {.layout = TensorLayout::GNWC}},
.weight = {.config = {.layout = TensorLayout::GKXC}},
.output = {.config = {.layout = TensorLayout::GNWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
@@ -30,8 +31,11 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
"128,64,64,64",
"GNWC,GKXC,EmptyTuple,GNWK",
"PassThrough,PassThrough,PassThrough",
"Default"});
}
#endif

View File

@@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,22 +30,26 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v1"});
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"Intrawave",
"v1"});
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -57,7 +62,10 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"Filter3x3",
"BlkGemmPipelineVersion: v5"});
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding",
"v5"});
}
} // namespace

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder;
using namespace ck_tile::builder::test_utils;
// Verifies that the builder produces the expected
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance for a 2-D BF16 forward
// convolution whose output tensor carries a fused SCALEADD_SCALEADD_RELU
// elementwise op with two auxiliary (bias) operands.
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu)
{
// Per-tensor signature (new refactored form): layouts, data types, and
// elementwise ops are specified on each tensor individually; accumulation
// is done in FP32.
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
// Weight data type is stated explicitly here even though it matches the
// signature-level default (BF16).
.weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}},
// The output op takes two auxiliary operands (NHWGK and strided G_K bias
// layouts), which appear as "Tuple(NHWGK,G_K)" in the expected instance
// string below.
.output = ConvolutionTensor{
.config = {.layout = TensorLayout::NHWGK},
.operation = TensorOperation<>{.elementwise_operation =
ElementwiseOperation::SCALEADD_SCALEADD_RELU}
.with_auxiliary_operand_configs<TensorLayout::NHWGK,
TensorLayout::G_K_strided>()}};
// Algorithm: 64-thread block, 64x32x32 tile, 2x2 XDL waves, MNK padding,
// single-stage prefetch with the default pipeline scheduler.
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x2_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
// Expected fragments of the generated instance string; presumably run_test
// checks each appears in the built instance's description — confirm against
// ckb_conv_test_utils.hpp.
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK",
"PassThrough,PassThrough,ScaleAddScaleAddRelu",
"64,64,32,32",
"MNKPadding",
"Default"});
}
} // namespace

View File

@@ -10,12 +10,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
@@ -26,19 +27,24 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
.with_dl_transfer(DlFwdTransfer);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
"256,128,128,16",
"Default",
"MNKPadding",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
@@ -50,8 +56,12 @@ TEST(FwdConvInstances,
.with_dl_transfer(DlFwdTransfer);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"});
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
"256,128,128,16",
"Filter1x1Pad0",
"MNKPadding",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}
} // namespace

View File

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +30,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
"Intrawave",
"v3",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP32,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NGCHW}},
.weight = {.config = {.layout = TensorLayout::GKCYX}},
.output = {.config = {.layout = TensorLayout::NGKHW}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +30,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"256,128,128,32",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
"Intrawave",
"v4",
"NGCHW,GKCYX,EmptyTuple,NGKHW",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -12,12 +12,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::FP8,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP8,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
@@ -28,8 +29,12 @@ TEST(FwdConvInstances,
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>(
{"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"});
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"256,256,128,32",
"Default",
"NHWGC,GKYXC,EmptyTuple,NHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -11,12 +11,13 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
@@ -30,20 +31,24 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
"256, 256, 128, 32",
"Default"});
"256,256,128,32",
"Default",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
TEST(
FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNHWC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
@@ -57,8 +62,11 @@ TEST(
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
"128, 128, 128, 32",
"Filter1x1Pad0"});
"128,128,128,32",
"Filter1x1Pad0",
"GNHWC,GKYXC,EmptyTuple,GNHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.data_type = DataType::BF16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::GNDHWC}},
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
.output = {.config = {.layout = TensorLayout::GNDHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -29,10 +31,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
"Intrawave",
"v3",
"GNDHWC,GKZYXC,EmptyTuple,GNDHWK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NDHWGC}},
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
.output = {.config = {.layout = TensorLayout::NDHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +32,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"256,128,128,32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
"Intrawave",
"v4",
"NDHWGC,GKZYXC,EmptyTuple,NDHWGK",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -12,12 +12,14 @@ using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
.data_type = DataType::FP32,
.elementwise_operation =
ElementwiseOperation::PASS_THROUGH};
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP32,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NGCDHW}},
.weight = {.config = {.layout = TensorLayout::GKCZYX}},
.output = {.config = {.layout = TensorLayout::NGKDHW}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
@@ -30,10 +32,13 @@ TEST(FwdConvInstances,
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"256,256,256,32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v1"});
"Intrawave",
"v1",
"NGCDHW,GKCZYX,EmptyTuple,NGKDHW",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -85,7 +85,10 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -212,7 +215,10 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
@@ -295,7 +301,10 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_THAT(Traits::layout,
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
ck_tile::builder::TensorLayout::GKYXC,
ck_tile::builder::TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);

View File

@@ -10,14 +10,48 @@ namespace ck_tile::builder::test {
using namespace ck_tile::builder;
struct TensorConfig
{
TensorLayout layout;
// Optional data types, override the type defined in the signature if provided.
DataType data_type{DataType::UNDEFINDED};
DataType compute_type{DataType::UNDEFINDED};
};
template <TensorConfig... Configs>
struct TensorOperation
{
ElementwiseOperation elementwise_operation{ElementwiseOperation::PASS_THROUGH};
std::array<TensorConfig, sizeof...(Configs)> auxiliary_operand_configs{Configs...};
// Add builder to add auxiliary tensor configs
template <auto... AuxiliaryConfigs>
constexpr auto with_auxiliary_operand_configs() const
{
return TensorOperation<Configs..., TensorConfig{AuxiliaryConfigs}...>{
.elementwise_operation = this->elementwise_operation};
}
};
template <typename Op = TensorOperation<>>
struct ConvolutionTensor
{
TensorConfig config;
Op operation{};
};
template <typename InputTensor = ConvolutionTensor<>,
typename WeightTensor = ConvolutionTensor<>,
typename OutputTensor = ConvolutionTensor<>>
struct ConvSignature
{
int spatial_dim;
ConvDirection direction;
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
DataType accumulation_data_type;
InputTensor input;
WeightTensor weight;
OutputTensor output;
};
static_assert(ConvSignatureDescriptor<ConvSignature>);
} // namespace ck_tile::builder::test

View File

@@ -16,40 +16,79 @@ namespace ckb = ck_tile::builder;
namespace ckr = ck_tile::reflect;
namespace ckt = ck_tile::test;
struct TensorOp
{
ckb::ElementwiseOperation elementwise_operation{ckb::ElementwiseOperation::PASS_THROUGH};
};
struct InvalidTensorOp
{
int elementwise_operation = 7; // invalid value
};
static_assert(!ckb::TensorOperatorDescriptor<InvalidTensorOp>);
struct TensorConfig
{
ckb::TensorLayout layout;
ckb::DataType data_type{ckb::DataType::UNDEFINDED};
ckb::DataType compute_type{ckb::DataType::UNDEFINDED};
};
struct ConvTensorSimple
{
TensorConfig config;
};
struct ConvTensorWithOp
{
TensorConfig config;
TensorOp operation{};
};
struct ConvTensorWithInvalidOp
{
TensorConfig config;
InvalidTensorOp operation{};
};
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, and data type.
struct ConvSignature
{
int spatial_dim = 2;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
// ckb::GroupConvDeviceOp device_operation =
// ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
int spatial_dim = 2;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}};
ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}};
ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}};
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
// Compile time tests for concepts
struct ConvSignatureWithOptionalParams
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
int spatial_dim = 2;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ConvTensorWithOp input = {
.config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16},
};
ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}};
ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16},
.operation = {ckb::ElementwiseOperation::SCALE}};
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
struct ConvSignatureWithInvalidOptionalParams
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
int elementwise_operation = 7; // this should fail
// ckb::GroupConvDeviceOp device_operation =
// ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
int spatial_dim = 2;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}};
ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}};
ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}};
};
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
struct DefaultAlgorithm
@@ -123,7 +162,9 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
"│ ├─ Input Layout: GNHWC\n"
"│ ├─ Weight Layout: GKYXC\n"
"│ ├─ Output Layout: GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"

View File

@@ -8,30 +8,38 @@
namespace {
using ::ck_tile::builder::factory::internal::ElementwiseOps;
using enum ::ck_tile::builder::ElementwiseOperation;
using ::ck_tile::builder::ElementwiseOperation;
using ::ck_tile::builder::factory::internal::ElementwiseOpToCK;
TEST(ConvElementwiseOp, AssignsOpsForPassThrough)
{
using Ops = ElementwiseOps<PASS_THROUGH>;
EXPECT_TRUE(
(std::is_same_v<Ops::AElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
EXPECT_TRUE(
(std::is_same_v<Ops::BElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
EXPECT_TRUE(
(std::is_same_v<Ops::CDEElementwiseOp, ck::tensor_operation::element_wise::PassThrough>));
using Op = ElementwiseOpToCK<ElementwiseOperation::PASS_THROUGH>::Op;
EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::PassThrough>));
}
TEST(ConvElementwiseOp, AssignsOpsForScale)
{
using Ops = ElementwiseOps<SCALE>;
using Op = ElementwiseOpToCK<ElementwiseOperation::SCALE>::Op;
EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::Scale>));
}
// Verifies that the builder's CLAMP enum maps to CK's Clamp elementwise functor.
TEST(ConvElementwiseOp, AssignsOpsForClamp)
{
    using Op = ElementwiseOpToCK<ElementwiseOperation::CLAMP>::Op;
    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::Clamp>));
}
// Verifies that the builder's SCALEADD_SCALEADD_RELU enum maps to CK's fused
// ScaleAddScaleAddRelu elementwise functor.
TEST(ConvElementwiseOp, AssignsOpsForScaleAddScaleAddRelu)
{
    using Op = ElementwiseOpToCK<ElementwiseOperation::SCALEADD_SCALEADD_RELU>::Op;
    EXPECT_TRUE((std::is_same_v<Op, ck::tensor_operation::element_wise::ScaleAddScaleAddRelu>));
}
// Verifies that the builder's BIAS_BNORM_CLAMP enum maps to CK's fused
// BiasNormalizeInInferClamp elementwise functor.
TEST(ConvElementwiseOp, AssignsOpsForBiasNormClamp)
{
    using Op = ElementwiseOpToCK<ElementwiseOperation::BIAS_BNORM_CLAMP>::Op;
    EXPECT_TRUE(
        (std::is_same_v<Op, ck::tensor_operation::element_wise::BiasNormalizeInInferClamp>));
}
} // namespace

View File

@@ -4,116 +4,481 @@
#include <gtest/gtest.h>
#include <type_traits>
// Include the helper file we're testing
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "impl/conv_signature_types.hpp"
namespace {
namespace ckb = ::ck_tile::builder;
using ::ck_tile::builder::DataType;
using ::ck_tile::builder::ElementwiseOperation;
using ::ck_tile::builder::TensorLayout;
using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts;
using ::ck_tile::builder::factory::internal::ConvTensorLayouts;
using ::ck_tile::builder::factory::internal::GetTensorLayout;
using ::ck_tile::builder::factory::internal::LayoutToCK;
using namespace ::ck_tile::builder::test;
using enum ::ck_tile::builder::ConvDirection;
// Each test below builds a minimal 1D forward-convolution signature and checks
// that ConvTensorLayouts maps the builder TensorLayout enums onto the matching
// ck::tensor_layout::convolution types, with an empty Ds tuple when there are
// no auxiliary tensors.
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 1,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NWGC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKXC}},
                        .output                 = {.config = {.layout = TensorLayout::NWGK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 1,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NGCW}},
                        .weight                 = {.config = {.layout = TensorLayout::GKXC}},
                        .output                 = {.config = {.layout = TensorLayout::NGKW}}};
    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 1,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::GNWC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKXC}},
                        .output                 = {.config = {.layout = TensorLayout::GNWK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 1,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NGCW}},
                        .weight                 = {.config = {.layout = TensorLayout::GKCX}},
                        .output                 = {.config = {.layout = TensorLayout::NGKW}}};
    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
// Each test below builds a minimal 2D forward-convolution signature and checks
// that ConvTensorLayouts maps the builder TensorLayout enums onto the matching
// ck::tensor_layout::convolution types, with an empty Ds tuple when there are
// no auxiliary tensors.
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 2,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NGCHW}},
                        .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
                        .output                 = {.config = {.layout = TensorLayout::NGKHW}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 2,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NHWGC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
                        .output                 = {.config = {.layout = TensorLayout::NHWGK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 2,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::GNHWC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
                        .output                 = {.config = {.layout = TensorLayout::GNHWK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 2,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NGCHW}},
                        .weight                 = {.config = {.layout = TensorLayout::GKCYX}},
                        .output                 = {.config = {.layout = TensorLayout::NGKHW}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
// Each test below builds a minimal 3D forward-convolution signature and checks
// that ConvTensorLayouts maps the builder TensorLayout enums onto the matching
// ck::tensor_layout::convolution types, with an empty Ds tuple when there are
// no auxiliary tensors.
TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 3,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NGCDHW}},
                        .weight                 = {.config = {.layout = TensorLayout::GKCZYX}},
                        .output                 = {.config = {.layout = TensorLayout::NGKDHW}}};
    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 3,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::NDHWGC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKZYXC}},
                        .output                 = {.config = {.layout = TensorLayout::NDHWGK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
{
    static constexpr auto sig =
        ConvSignature<>{.spatial_dim            = 3,
                        .direction              = FORWARD,
                        .data_type              = DataType::FP16,
                        .accumulation_data_type = DataType::FP32,
                        .input                  = {.config = {.layout = TensorLayout::GNDHWC}},
                        .weight                 = {.config = {.layout = TensorLayout::GKZYXC}},
                        .output                 = {.config = {.layout = TensorLayout::GNDHWK}}};
    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
// Verifies LayoutToCK for the strided G_K auxiliary (bias) layout.
TEST(AuxiliaryTensorLayout, AssignsLayoutForG_K_strided)
{
    using CKLayout = LayoutToCK<TensorLayout::G_K_strided>::type;
    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::G_K>));
}
// Verifies LayoutToCK for the packed GC auxiliary layout.
TEST(AuxiliaryTensorLayout, AssignsLayoutForGC)
{
    using CKLayout = LayoutToCK<TensorLayout::GC>::type;
    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::GC>));
}
// Verifies LayoutToCK for the strided G_C auxiliary layout.
TEST(AuxiliaryTensorLayout, AssignsLayoutForG_C_strided)
{
    using CKLayout = LayoutToCK<TensorLayout::G_C_strided>::type;
    EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::G_C>));
}
// With no auxiliary tensors, the Ds layout collection must be an empty tuple.
TEST(AuxiliaryTensorLayout, EmptyAuxiliaryTensorLayoutIsEmptyTuple)
{
    using ::ck_tile::builder::factory::internal::EmptyAuxiliaryTensorLayout;
    using EmptyLayout = EmptyAuxiliaryTensorLayout::type;
    EXPECT_TRUE((std::is_same_v<EmptyLayout, ck::Tuple<>>));
}
// Minimal stand-in for an auxiliary tensor config: only the layout field that
// AuxiliaryTensorLayouts reads.
struct MockAuxiliaryTensorConfig
{
    TensorLayout layout;
};
// Each test below feeds a constexpr array of mock auxiliary configs into
// AuxiliaryTensorLayouts and checks both the reported Size and the resulting
// ck::Tuple of CK layout types.
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 1);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 1);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 1);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
// Order of the configs must be preserved in the resulting tuple.
TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 2> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 2);
    using ExpectedType =
        ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 3> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC},
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 3);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
                                   ck::tensor_layout::convolution::GC,
                                   ck::tensor_layout::convolution::G_C>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
// The spatial-dimension template argument should not affect auxiliary layouts.
TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 1);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
{
    static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
        MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
    using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
    EXPECT_EQ(AuxLayouts::Size, 1);
    using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
    EXPECT_TRUE((std::is_same_v<AuxLayouts::type, ExpectedType>));
}
// Each test below builds a full conv signature whose output tensor carries a
// TensorOperation with auxiliary (bias) tensor configs, then checks that
// ConvTensorLayouts produces both the A/B/E layouts and the Ds tuple built
// from those auxiliary configs.
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
{
    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
    static constexpr auto sig =
        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
            .spatial_dim            = 2,
            .direction              = FORWARD,
            .data_type              = DataType::FP16,
            .accumulation_data_type = DataType::FP32,
            .input                  = {.config = {.layout = TensorLayout::NGCHW}},
            .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
            .output                 = {.config = {.layout = TensorLayout::NGKHW},
                                       .operation =
                                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
}
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
{
    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::GC}>;
    static constexpr auto sig =
        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
            .spatial_dim            = 2,
            .direction              = FORWARD,
            .data_type              = DataType::BF16,
            .accumulation_data_type = DataType::FP32,
            .input                  = {.config = {.layout = TensorLayout::NHWGC}},
            .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
            .output                 = {.config = {.layout = TensorLayout::NHWGK},
                                       .operation =
                                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
}
// Two auxiliary configs must yield a two-element Ds tuple in config order.
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
{
    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided},
                                     TensorConfig{.layout = TensorLayout::GC}>;
    static constexpr auto sig =
        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
            .spatial_dim            = 2,
            .direction              = FORWARD,
            .data_type              = DataType::FP16,
            .accumulation_data_type = DataType::FP32,
            .input                  = {.config = {.layout = TensorLayout::GNHWC}},
            .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
            .output                 = {.config = {.layout = TensorLayout::GNHWK},
                                       .operation = OutputOp{.elementwise_operation =
                                                                 ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
    using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
    using ExpectedDsLayout =
        ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
}
TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
{
    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
    static constexpr auto sig =
        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
            .spatial_dim            = 1,
            .direction              = FORWARD,
            .data_type              = DataType::FP32,
            .accumulation_data_type = DataType::FP32,
            .input                  = {.config = {.layout = TensorLayout::NWGC}},
            .weight                 = {.config = {.layout = TensorLayout::GKXC}},
            .output                 = {.config = {.layout = TensorLayout::NWGK},
                                       .operation =
                                           OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
    using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
}
TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
{
    using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_C_strided}>;
    static constexpr auto sig =
        ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
            .spatial_dim            = 3,
            .direction              = FORWARD,
            .data_type              = DataType::FP16,
            .accumulation_data_type = DataType::FP32,
            .input                  = {.config = {.layout = TensorLayout::NDHWGC}},
            .weight                 = {.config = {.layout = TensorLayout::GKZYXC}},
            .output                 = {.config = {.layout = TensorLayout::NDHWGK},
                                       .operation = OutputOp{.elementwise_operation =
                                                                 ElementwiseOperation::BIAS_BNORM_CLAMP}}};
    using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
    EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
    using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
    EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
}
} // namespace

View File

@@ -9,71 +9,42 @@
namespace {
namespace ckb = ck_tile::builder;
using ck_tile::builder::factory::internal::ConvTensorTypes;
using ck_tile::builder::factory::internal::DataTypeToCK;
// Verifies that the builder's FP16 enum maps to CK's half type.
TEST(ConvTensorType, AssignsTypesForFP16)
{
    using CKType = DataTypeToCK<ckb::DataType::FP16>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::half_t>));
}
// Verifies that the builder's BF16 enum maps to CK's bfloat16 type.
TEST(ConvTensorType, AssignsTypesForBF16)
{
    using CKType = DataTypeToCK<ckb::DataType::BF16>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::bhalf_t>));
}
// Verifies that the builder's FP32 enum maps to float.
TEST(ConvTensorType, AssignsTypesForFP32)
{
    using CKType = DataTypeToCK<ckb::DataType::FP32>::type;
    EXPECT_TRUE((std::is_same_v<CKType, float>));
}
// Verifies that the builder's INT32 enum maps to int32_t.
TEST(ConvTensorType, AssignsTypesForINT32)
{
    using CKType = DataTypeToCK<ckb::DataType::INT32>::type;
    EXPECT_TRUE((std::is_same_v<CKType, int32_t>));
}
// Verifies that the builder's I8 enum maps to int8_t.
TEST(ConvTensorType, AssignsTypesForI8)
{
    using CKType = DataTypeToCK<ckb::DataType::I8>::type;
    EXPECT_TRUE((std::is_same_v<CKType, int8_t>));
}
// Verifies that the builder's FP8 enum maps to CK's 8-bit float type.
TEST(ConvTensorType, AssignsTypesForFP8)
{
    using CKType = DataTypeToCK<ckb::DataType::FP8>::type;
    EXPECT_TRUE((std::is_same_v<CKType, ck::f8_t>));
}
} // namespace

View File

@@ -178,6 +178,9 @@ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
// Gridwise XDL GEMM presets for forward convolution. All presets fix the
// K-dimension vector sizes (ak1/bk1 = 8) and the 32x32 XDL instruction tile;
// they differ only in how many XDL tiles each wave computes along M x N.
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
// Smallest preset: useful for problems with limited N parallelism.
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
    .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};

View File

@@ -15,7 +15,7 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
{
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
const auto kernel_string = instance.GetInstanceString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);