diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index ba2087cfa3..de8ba4f648 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -218,6 +218,14 @@ struct ElementwiseOps using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough; }; +template <> +struct ElementwiseOps +{ + using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale; +}; + // The algorithm specializations for the convolution and GEMM. template requires( @@ -365,6 +373,10 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { return ck::BlockGemmPipelineVersion::v1; } + else if constexpr(version == BlockGemmPipelineVersion::V2) + { + return ck::BlockGemmPipelineVersion::v2; + } else if constexpr(version == BlockGemmPipelineVersion::V3) { return ck::BlockGemmPipelineVersion::v3; @@ -434,9 +446,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - // Check preconditions for the algorithm description. - static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3, - "Only 2D and 3D convolutions are supported in this factory."); static_assert(SpecifiesThreadBlock, "The convolution algorithm descriptor must specify thread block info."); static_assert(SpecifiesGridwiseGemm, diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index bafa95862a..3a13f7239f 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -22,11 +22,22 @@ add_ck_builder_test(test_conv_builder test_instance_traits.cpp test_instance_traits_util.cpp) +add_ck_builder_test(test_inline_diff test_inline_diff.cpp) + # Testing the virtual GetInstanceString methods requires kernel compilation. add_ck_builder_test(test_get_instance_string test_get_instance_string.cpp) -add_ck_builder_test(test_inline_diff test_inline_diff.cpp) +# Testing the fwd convolution builder requires kernel compilation. +# To enable parallel compilation, the individual tests are split into separate files. +add_ck_builder_test(test_ckb_build_fwd_instances + conv/test_ckb_conv_fwd_1d_bf16.cpp + conv/test_ckb_conv_fwd_2d_bf16.cpp + conv/test_ckb_conv_fwd_2d_fp16.cpp + conv/test_ckb_conv_fwd_2d_fp32.cpp + conv/test_ckb_conv_fwd_3d_bf16.cpp + conv/test_ckb_conv_fwd_3d_fp16.cpp + conv/test_ckb_conv_fwd_3d_fp32.cpp) function(add_ck_factory_test test_name) add_ck_builder_test(${test_name} ${ARGN}) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp new file mode 100644 index 0000000000..d5b8802896 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -0,0 +1,28 @@ +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +// 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE +// elementwise op +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, + .data_type = DataType::BF16, + .elementwise_operation = ElementwiseOperation::SCALE}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + + run_test(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 433b39884b..77c5c80489 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -2,13 +2,11 @@ using namespace ck_tile::builder::test_utils; -class FwdConv2DBF16Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { // 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT -TEST_F(FwdConv2DBF16Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, @@ -27,8 +25,8 @@ TEST_F(FwdConv2DBF16Test, } // 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3 -TEST_F(FwdConv2DBF16Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, @@ -45,3 +43,5 @@ TEST_F(FwdConv2DBF16Test, BlockGemmPipelineVersion::V5, ConvFwdSpecialization::FILTER_3x3>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 2b2109a141..c81d7543bb 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -2,12 +2,10 @@ using namespace ck_tile::builder::test_utils; -class FwdConv2DFP16Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { -TEST_F(FwdConv2DFP16Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, @@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP16Test, BlockGemmPipelineVersion::V3, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 3eade37659..d55a120bb8 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -2,12 +2,10 @@ using namespace ck_tile::builder::test_utils; -class FwdConv2DFP32Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { -TEST_F(FwdConv2DFP32Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, @@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP32Test, BlockGemmPipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 6bc62153cd..f7bcf49e54 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -2,13 +2,11 @@ using namespace ck_tile::builder::test_utils; -class FwdConv3DBF16Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { // 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT -TEST_F(FwdConv3DBF16Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, @@ -25,3 +23,5 @@ TEST_F(FwdConv3DBF16Test, BlockGemmPipelineVersion::V3, ConvFwdSpecialization::DEFAULT>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index c23e58c702..27b5ddc821 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -2,13 +2,11 @@ using namespace ck_tile::builder::test_utils; -class FwdConv3DFP16Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { // 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0 -TEST_F(FwdConv3DFP16Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, @@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP16Test, BlockGemmPipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index deaf2038e2..c0b6f04383 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -2,13 +2,11 @@ using namespace ck_tile::builder::test_utils; -class FwdConv3DFP32Test : public FwdConvBuilderTestBase -{ -}; +namespace ck_tile::builder::testing { // 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0 -TEST_F(FwdConv3DFP32Test, - Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, @@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP32Test, BlockGemmPipelineVersion::V1, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 37ee3a953a..7ad01bd922 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -10,11 +10,6 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -// Common test base class -class FwdConvBuilderTestBase : public ::testing::Test -{ -}; - // Common test implementation template