diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index de8ba4f648..252d423716 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -158,6 +158,28 @@ struct ConvTensorLayouts +consteval auto GetTensorLayout() +{ + // Chooses the layouts type from SPATIAL_DIM (1, 2 or 3); the trailing assert depends on SPATIAL_DIM so it only fires if that branch is instantiated. + if constexpr(SPATIAL_DIM == 1) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 2) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 3) + { + return factory_internal::ConvTensorLayouts{}; + } + else + { + static_assert(SPATIAL_DIM >= 1 && SPATIAL_DIM <= 3, "Unsupported spatial dimension for convolution layout."); + } +} + // Type mappings from builder convolution data type to CK tensor types. template struct ConvTensorTypes @@ -440,8 +462,8 @@ template { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = - factory_internal::ConvTensorLayouts; + // Layouts is the return type of factory_internal::GetTensorLayout. + using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 0851f0061e..76e5590ad6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -17,6 +17,7 @@ // signature at compile time. 
#pragma once +#include #include #include @@ -40,16 +41,21 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +template +concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; +// ConvDeviceOp (above) and ConvLayout (below) are std::same_as checks against GroupConvDeviceOp and GroupConvLayout. +template +concept ConvLayout = std::same_as, GroupConvLayout>; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; { t.direction } -> std::convertible_to; - requires std::convertible_to || - std::convertible_to || - std::convertible_to; + { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; + { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7f49e77f81..509f240edd 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -48,6 +48,18 @@ enum class GroupConvLayout3D NGCDHW_GKCZYX_NGKDHW, }; +struct GroupConvLayout { + union { + GroupConvLayout1D _1d; + GroupConvLayout2D _2d; + GroupConvLayout3D _3d; + }; + // Each converting constructor initializes the union member that matches its enum argument; no tag records which one is active. + constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} + constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} + constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} +}; + // Direction of the convolution operation. enum class ConvDirection { @@ -56,6 +68,50 @@ enum class ConvDirection BACKWARD_WEIGHT }; +// Forward convolution device operations. 
+enum class FwdGroupConvDeviceOperation +{ + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +}; + +// Backward data convolution device operations. +enum class BwdDataGroupConvDeviceOperation +{ + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, + DeviceGroupedConvBwdDataMultipleD, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle +}; + +// Backward weight convolution device operations. +enum class BwdWeightGroupConvDeviceOperation +{ + DeviceGroupedConvBwdWeight, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeightMultipleD, + DeviceGroupedConvBwdWeight_Dl +}; + +// Structural type for a device operation: a union over the forward, backward-data and backward-weight enums. +struct GroupConvDeviceOp { + union { + FwdGroupConvDeviceOperation _fwd; + BwdDataGroupConvDeviceOperation _bwd_data; + BwdWeightGroupConvDeviceOperation _bwd_weight; + }; + // Non-explicit constructors: a plain enumerator converts implicitly and initializes the matching union member. + constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} + constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} + constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} +}; + // Fused element-wise operations. 
enum class ElementwiseOperation { diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index d5b8802896..b660fa3303 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -9,12 +9,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 1, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE}; + .elementwise_operation = ElementwiseOperation::SCALE, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 77c5c80489..cf942f56a1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock 
FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -28,12 +30,14 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index c81d7543bb..efd3ecc680 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index d55a120bb8..a7248d25b5 100644 --- 
a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index f7bcf49e54..b8c8bc7063 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp 
b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 27b5ddc821..035a9df36d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index c0b6f04383..2713dd1b01 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + }; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff 
--git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 297f827395..cc5490c711 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -3,11 +3,13 @@ #pragma once +#include #include "ck_tile/builder/conv_signature_concepts.hpp" namespace ck_tile::builder::test { -template +using namespace ck_tile::builder; + struct ConvSignature { int spatial_dim; @@ -15,9 +17,8 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; + GroupConvDeviceOp device_operation; }; -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); +static_assert(ConvSignatureDescriptor); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7ad01bd922..cd3943d26f 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -11,7 +11,7 @@ using namespace ck_tile::builder; using namespace test; // Common test implementation -template