[CK_BUILDER] Clean-up fwd conv builder implementation (#3110)

This commit is contained in:
Ville Pietilä
2025-10-29 05:37:33 +02:00
committed by GitHub
parent 515e283091
commit 13e13ce359
10 changed files with 84 additions and 48 deletions

View File

@@ -218,6 +218,14 @@ struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};
// Specialization selected for ElementwiseOperation::SCALE.
// The A and B elementwise ops are PassThrough; the CDE elementwise op is Scale
// (the PASS_THROUGH specialization above uses PassThrough for CDE instead).
template <>
struct ElementwiseOps<ElementwiseOperation::SCALE>
{
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
};
// The algorithm specializations for the convolution and GEMM.
template <typename CONV_ENUM>
requires(
@@ -365,6 +373,10 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
@@ -434,9 +446,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Check preconditions for the algorithm description.
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported in this factory.");
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,

View File

@@ -22,11 +22,22 @@ add_ck_builder_test(test_conv_builder
test_instance_traits.cpp
test_instance_traits_util.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
# Testing the virtual GetInstanceString methods requires kernel compilation.
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.
# The sources below are named per spatial dimension and data type (1D BF16; 2D and 3D
# BF16/FP16/FP32) and are all passed to a single add_ck_builder_test call.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)
function(add_ck_factory_test test_name)
add_ck_builder_test(${test_name} ${ARGN})

View File

@@ -0,0 +1,28 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
// 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE
// elementwise op
// The test body contains no assertions of its own; it only instantiates the shared
// run_test<> helper with the configuration declared below.
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
{
// Compile-time convolution signature: 1 spatial dimension, forward direction,
// NGCW_GKXC_NGKW layout, BF16 data type, SCALE elementwise operation.
constexpr ConvSignature<GroupConvLayout1D> FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::SCALE};
// Thread block description: block_size 256 with a tile of m = 256, n = 256, k = 32.
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
// Template arguments, in order: signature, thread block, block GEMM pipeline version,
// forward convolution specialization.
// NOTE(review): run_test is presumably the helper from utils/ckb_conv_test_common.hpp
// that ends with EXPECT_NE(invoker_ptr, nullptr) - confirm.
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,13 +2,11 @@
using namespace ck_tile::builder::test_utils;
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
@@ -27,8 +25,8 @@ TEST_F(FwdConv2DBF16Test,
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
@@ -45,3 +43,5 @@ TEST_F(FwdConv2DBF16Test,
BlockGemmPipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,12 +2,10 @@
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
TEST_F(FwdConv2DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
@@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP16Test,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,12 +2,10 @@
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
TEST_F(FwdConv2DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
@@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP32Test,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,13 +2,11 @@
using namespace ck_tile::builder::test_utils;
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
TEST_F(FwdConv3DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DBF16Test,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,13 +2,11 @@
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP16Test,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -2,13 +2,11 @@
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
{
};
namespace ck_tile::builder::testing {
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP32Test,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -10,11 +10,6 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Common test base class
// Empty GoogleTest fixture: it derives from ::testing::Test and declares no members,
// so it adds no state or setup/teardown of its own.
class FwdConvBuilderTestBase : public ::testing::Test
{
};
// Common test implementation
template <auto FwdConvSignature,
ThreadBlock FwdThreadBlock,
@@ -93,11 +88,4 @@ constexpr void run_test()
EXPECT_NE(invoker_ptr, nullptr);
}
// Common thread block configurations
// Both configurations use block_size 256 and k = 32; they differ only in the m and n
// tile sizes.
// Default: m = 256, n = 256.
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
// Small: m = 128, n = 128.
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
} // namespace ck_tile::builder::test_utils