[CK_BUILDER] Add conv factories for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle and DeviceGroupedConvFwdMultipleD_Wmma_CShuffle (#3138)

* Add device operation to conv signature. Use unions to hold conv layouts and device operations.

* Add predicates for all device op instances.

* Use the device op signature for validation.

* Fix ckb CMakeLists.txt file for tests.

* Fix building CK Builder instance traits after the introduction of direct load template parameter in CK.

* Fix clang-formatting.

* Add factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device op.

* Add conv factory for  DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

* Rename elements per wave per shuffle member in the epilogue concept.

* clang-format

* Add concepts and types for optional device op template parameters.

* Add optional compute, direct load, and loop scheduler arguments to conv factory.

* Add number of groups to merge template parameter.

* clang-format.
This commit is contained in:
Ville Pietilä
2025-11-03 09:03:25 +02:00
committed by GitHub
parent 16e85cf179
commit 3ae3992c18
16 changed files with 986 additions and 168 deletions

View File

@@ -21,10 +21,11 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -0,0 +1,28 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
// 1D FP16 (channels-last) with DEFAULT specialization
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NWGC_GKXC_NWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle};
constexpr ThreadBlock FwdThreadBlock{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::DEFAULT>();
}
} // namespace ck_tile::builder::testing

View File

@@ -0,0 +1,28 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
// 1D I8 (channels-last) with and DEFAULT specialization
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::GNWC_GKXC_GNWK,
.data_type = DataType::I8,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle};
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::DEFAULT>();
}
} // namespace ck_tile::builder::testing

View File

@@ -20,10 +20,10 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
@@ -42,10 +42,10 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}
} // namespace ck_tile::builder::testing

View File

@@ -19,10 +19,11 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -19,10 +19,11 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -20,10 +20,10 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}
} // namespace ck_tile::builder::testing

View File

@@ -20,10 +20,11 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -20,10 +20,11 @@ TEST(FwdConvInstances,
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing