mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Clean-up fwd conv builder implementation (#3110)
This commit is contained in:
@@ -218,6 +218,14 @@ struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
// Explicit specialization of ElementwiseOps for ElementwiseOperation::SCALE.
// It exposes three type aliases: the A and B elementwise ops are both
// PassThrough, and only the CDE elementwise op is Scale.
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::SCALE>
|
||||
{
|
||||
// A-side elementwise op: PassThrough.
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
// B-side elementwise op: PassThrough.
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
// CDE elementwise op: Scale (the only alias that is not PassThrough here).
using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
|
||||
};
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
@@ -365,6 +373,10 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
@@ -434,9 +446,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
// Check preconditions for the algorithm description.
|
||||
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
|
||||
"Only 2D and 3D convolutions are supported in this factory.");
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
|
||||
|
||||
@@ -22,11 +22,22 @@ add_ck_builder_test(test_conv_builder
|
||||
test_instance_traits.cpp
|
||||
test_instance_traits_util.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
|
||||
# Testing the fwd convolution builder requires kernel compilation.
|
||||
# To enable parallel compilation, the individual tests are split into separate files.
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
|
||||
28
experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp
Normal file
28
experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE
|
||||
// elementwise op
|
||||
// GoogleTest case in the FwdConvInstances suite: describes a 1D BF16 forward
// convolution with the SCALE elementwise operation at compile time and hands
// that description to run_test.
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
|
||||
{
|
||||
// Compile-time convolution signature: one spatial dimension, forward
// direction, NGCW_GKXC_NGKW layout, BF16 data type, SCALE elementwise op.
constexpr ConvSignature<GroupConvLayout1D> FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE};
|
||||
|
||||
// Thread-block descriptor: block_size 256 with tile_size m = 256, n = 256, k = 32.
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
// Instantiate the shared test helper with the signature, the thread block,
// pipeline version V2 and the FILTER_1X1_STRIDE1_PAD0 specialization.
// NOTE(review): run_test is presumably provided by the included
// utils/ckb_conv_test_common.hpp; its checks are not visible here - confirm.
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V2,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
@@ -27,8 +25,8 @@ TEST_F(FwdConv2DBF16Test,
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
@@ -45,3 +43,5 @@ TEST_F(FwdConv2DBF16Test,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST_F(FwdConv2DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
@@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP16Test,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST_F(FwdConv2DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
@@ -24,3 +22,5 @@ TEST_F(FwdConv2DFP32Test,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST_F(FwdConv3DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DBF16Test,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP16Test,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
@@ -25,3 +23,5 @@ TEST_F(FwdConv3DFP32Test,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
|
||||
@@ -10,11 +10,6 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test base class
|
||||
class FwdConvBuilderTestBase : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
@@ -93,11 +88,4 @@ constexpr void run_test()
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Common thread block configurations
|
||||
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
Reference in New Issue
Block a user