From 651dc5343b435e722fdfe06ef90e4d9b92761c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Mon, 3 Nov 2025 09:03:25 +0200 Subject: [PATCH] [CK_BUILDER] Add conv factories for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle and DeviceGroupedConvFwdMultipleD_Wmma_CShuffle (#3138) * Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. * Add factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device op. * Add conv factory for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle * Rename elements per wave per shuffle member in the epilogue concept. * clang-format * Add concepts and types for optional device op template parameters. * Add optional compute, direct load, and loop scheduler arguments to conv factory. * Add number of groups to merge template parameter. * clang-format. 
[ROCm/composable_kernel commit: 3ae3992c18045446f1b733b306265efbd14c5d57] --- .../builder/conv_algorithm_concepts.hpp | 67 +- .../ck_tile/builder/conv_algorithm_limits.hpp | 4 +- .../include/ck_tile/builder/conv_factory.hpp | 600 +++++++++++++++--- .../builder/include/ck_tile/builder/types.hpp | 45 ++ experimental/builder/test/CMakeLists.txt | 4 +- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 9 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 28 + .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 28 + .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 16 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 9 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 9 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 8 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 9 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 9 +- .../test/impl/conv_algorithm_types.hpp | 132 +++- .../test/utils/ckb_conv_test_common.hpp | 177 +++++- 16 files changed, 986 insertions(+), 168 deletions(-) create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 078c066e55..586a119c75 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -24,9 +24,9 @@ concept ThreadBlockDescriptor = requires(T t) { { t.tile_size.k } -> std::convertible_to; }; -// Concept for parameters that describe a gridwise GEMM problem. +// Concept for parameters that describe a gridwise XDL GEMM problem. 
template -concept GridwiseGemmDescriptor = requires(T t) { +concept GridwiseXdlGemmDescriptor = requires(T t) { { t.ak1 } -> std::convertible_to; { t.bk1 } -> std::convertible_to; { t.m_per_xdl } -> std::convertible_to; @@ -35,6 +35,24 @@ concept GridwiseGemmDescriptor = requires(T t) { { t.n_xdl_per_wave } -> std::convertible_to; }; +// Concept for parameter that describe block GEMM problem. +template +concept BlockGemmDescriptor = requires(T t) { + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept for parameters that describe a gridwise WMMA GEMM problem. +template +concept GridwiseWmmaGemmDescriptor = requires(T t) { + { t.k1 } -> std::convertible_to; + { t.m_per_wmma } -> std::convertible_to; + { t.n_per_wmma } -> std::convertible_to; + { t.m_wmma_per_wave } -> std::convertible_to; + { t.n_wmma_per_wave } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; +}; + // Concept for vectorized data transfer for convolution input tensors. template concept BlockTransferDescriptor = requires(T t) { @@ -66,8 +84,8 @@ concept LdsTransferDescriptor = requires(T t) { // LDS). template concept EpilogueDescriptor = requires(T t) { - { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.n_xdl_per_wave_per_shuffle } -> std::convertible_to; + { t.m_per_wave_per_shuffle } -> std::convertible_to; + { t.n_per_wave_per_shuffle } -> std::convertible_to; { t.scalar_per_vector } -> std::convertible_to; }; @@ -77,7 +95,7 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// No requirements yet for a ConvAlogorithm concept. +// No requirements yet for a ConvAlgorithm concept. template concept ConvAlgorithmDescriptor = std::is_class_v; @@ -91,10 +109,16 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; -// Concept to check if a struct specifies gridwise GEMM info. 
+// Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseGemm = requires { - { T::gridwise_gemm } -> GridwiseGemmDescriptor; +concept SpecifiesGridwiseXdlGemm = requires { + { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise WMMA GEMM info. +template +concept SpecifiesGridwiseWmmaGemm = requires { + { T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. @@ -127,10 +151,11 @@ concept SpecifiesSourceAccessOrder = requires(T t) { { T::block_transfer.src_access_order_b } -> AccessOrderDescriptor; }; -// Concept to check if struct specifies block_gemm_pipeline_version. +// Concept to check if struct specifies block GEMM. template -concept SpecifiesGemmPipelineVersion = requires { - { T::pipeline_version } -> std::convertible_to; +concept SpecifiesBlockGemm = requires { + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -138,4 +163,24 @@ concept SpecifiesFwdConcSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; +template +concept SpecifiesGemmSpecialization = requires { + { T::gemm_specialization } -> std::convertible_to; +}; + +template +concept SpecifiesNumPrefetchStages = requires { + { T::num_gemm_k_prefetch_stages } -> std::convertible_to; +}; + +template +concept SpecifiesNumGroupsToMerge = requires { + { T::num_groups_to_merge } -> std::convertible_to; +}; + +template +concept SpecifiesLoopScheduler = requires { + { T::loop_scheduler } -> std::convertible_to; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 7ef8930273..68d5ec5a83 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -18,8 +18,8 @@ concept InputVectorTransferLimits = requires { // Limits for output vector transfer. template concept OutputVectorTransferLimits = requires { - requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 && - Value.n_xdl_per_wave_per_shuffle > 0; + requires Value.scalar_per_vector > 0 && Value.m_per_wave_per_shuffle > 0 && + Value.n_per_wave_per_shuffle > 0; }; // Limits for access order. Must be a permutation of {0, 1, 2}. diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 31be8c322c..8ea3e18d65 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -36,6 +36,8 @@ #pragma once +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" @@ -194,7 +196,9 @@ template <> struct ConvTensorTypes { using ADataType = ck::half_t; + using AComputeType = ck::half_t; using BDataType = ck::half_t; + using BComputeType = ck::half_t; using CShuffleDataType = ck::half_t; using DsDataTypes = ck::Tuple<>; using AccDataType = float; @@ -205,7 +209,9 @@ template <> struct ConvTensorTypes { using ADataType = ck::bhalf_t; + using AComputeType = ck::bhalf_t; using BDataType = ck::bhalf_t; + using BComputeType = ck::bhalf_t; using CShuffleDataType = ck::bhalf_t; using DsDataTypes = ck::Tuple<>; using AccDataType = float; @@ -216,13 +222,28 @@ template <> struct ConvTensorTypes { using ADataType = float; + using AComputeType = float; using BDataType = 
float; + using BComputeType = float; using CShuffleDataType = float; using DsDataTypes = ck::Tuple<>; using AccDataType = float; using EDataType = float; }; +template <> +struct ConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + template struct ElementwiseOps { @@ -262,6 +283,61 @@ struct ConvSpec template ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec; +struct BlockGemmSpec +{ + ck::BlockGemmPipelineVersion pipeline_version; + ck::BlockGemmPipelineScheduler scheduler; +}; + +template +constexpr BlockGemmSpec SetBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + ck::BlockGemmPipelineScheduler scheduler; + ck::BlockGemmPipelineVersion version; + + if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + { + scheduler = ck::BlockGemmPipelineScheduler::Intrawave; + } + else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) + { + scheduler = ck::BlockGemmPipelineScheduler::Interwave; + } + else + { + static_assert(false, "Unknown BlockGemmPipelineScheduler"); + } + + if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + { + version = ck::BlockGemmPipelineVersion::v1; + } + else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) + { + version = ck::BlockGemmPipelineVersion::v2; + } + else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) + { + version = ck::BlockGemmPipelineVersion::v3; + } + else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) + { + version = ck::BlockGemmPipelineVersion::v4; + } + else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) + { + version = ck::BlockGemmPipelineVersion::v5; + } + else + { + static_assert(false, "Unknown BlockGemmPipelineVersion"); + } + + return BlockGemmSpec{.pipeline_version = version, 
.scheduler = scheduler}; +} + // Block info for a convolution. struct MNK { @@ -283,31 +359,6 @@ constexpr ConvBlock SetThreadBlockInfo() .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}}; } -// Convolution tuning parameters. -struct GridwiseGemm -{ - size_t ak1 = 0; - size_t bk1 = 0; - size_t m_per_xdl = 0; - size_t n_per_xdl = 0; - size_t m_xdl_per_wave = 0; - size_t n_xdl_per_wave = 0; -}; - -template -constexpr GridwiseGemm SetGridwiseGemmInfo() -{ - constexpr auto& TP = ALGORITHM.gridwise_gemm; - return GridwiseGemm{ - .ak1 = TP.ak1, - .bk1 = TP.bk1, - .m_per_xdl = TP.m_per_xdl, - .n_per_xdl = TP.n_per_xdl, - .m_xdl_per_wave = TP.m_xdl_per_wave, - .n_xdl_per_wave = TP.n_xdl_per_wave, - }; -} - // Block transfer parameters for A or B tensor. struct BlockTransfer { @@ -362,8 +413,8 @@ constexpr BlockTransfer SetFwdConvBBlockTransfer() // Block transfer parameters for C tensor. struct CBlockTransfer { - size_t m_xdl_per_wave_per_shuffle = 0; - size_t n_xdl_per_wave_per_shuffle = 0; + size_t m_per_wave_per_shuffle = 0; + size_t n_per_wave_per_shuffle = 0; ck::Array thread_cluster_dims = {0, 0, 0, 0}; size_t scalar_per_vector = 0; }; @@ -373,8 +424,8 @@ constexpr CBlockTransfer SetCBlockTransfer() { constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c; constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c; - CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle, - .n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle, + CBlockTransfer block_transfer{.m_per_wave_per_shuffle = EPC.m_per_wave_per_shuffle, + .n_per_wave_per_shuffle = EPC.n_per_wave_per_shuffle, .thread_cluster_dims = { TCL.m_block, @@ -386,6 +437,130 @@ constexpr CBlockTransfer SetCBlockTransfer() return block_transfer; } +template +consteval ck::LoopScheduler SetLoopScheduler() +{ + constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; + + if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + { + return 
ck::LoopScheduler::Default; + } + else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) + { + return ck::LoopScheduler::Interwave; + } + else + { + static_assert(false, "Unknown LoopScheduler"); + } +} + +template +consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() +{ + constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; + if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + { + return ck::PipelineVersion::v1; + } + else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) + { + return ck::PipelineVersion::v2; + } + else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) + { + static_assert(false, "V3 is used only for stream-K."); + } + else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) + { + return ck::PipelineVersion::v4; + } + else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) + { + return ck::PipelineVersion::weight_only; + } + else + { + static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + } +} + +template +consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() +{ + constexpr auto gemm_spec = ALGORITHM.gemm_specialization; + + if constexpr(gemm_spec == GemmSpecialization::Default) + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + else if constexpr(gemm_spec == GemmSpecialization::MPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::NPadding) + { + return ck::tensor_operation::device::GemmSpecialization::NPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::KPadding) + { + return ck::tensor_operation::device::GemmSpecialization::KPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MNPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MNPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MKPadding) + { + return 
ck::tensor_operation::device::GemmSpecialization::MKPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::NKPadding) + { + return ck::tensor_operation::device::GemmSpecialization::NKPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MNKPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::OPadding) + { + return ck::tensor_operation::device::GemmSpecialization::OPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::NOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::NOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::KOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::KOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MNOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MNOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MKOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MKOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::NKOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::NKOPadding; + } + else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding) + { + return ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + } + else + { + static_assert(false, "Unknown GemmSpecialization"); + } +} + template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { @@ -473,7 +648,7 @@ struct ConvFactory static_assert(SpecifiesThreadBlock, "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseGemm, + static_assert(SpecifiesGridwiseXdlGemm, "The convolution algorithm descriptor must specify gridwise GEMM info."); static_assert(SpecifiesBlockTransfer, "The 
convolution algorithm descriptor must specify block transfer info."); @@ -484,30 +659,34 @@ struct ConvFactory "The convolution algorithm descriptor must specify thread cluster access order info."); static_assert(SpecifiesSourceAccessOrder, "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesGemmPipelineVersion, - "The convolution algorithm descriptor must specify block gemm pipeline version."); + static_assert(SpecifiesBlockGemm, + "The convolution algorithm descriptor must specify block gemm pipeline."); static_assert(SpecifiesFwdConcSpecialization, "The convolution algorithm descriptor must specify forward convolution " "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == + ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, + "A and B block transfers must both be direct load or not."); + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); - static constexpr factory_internal::ConvSpec SPECIALIZATION{ - .conv_spec = FWD_CONV_SPECIALIZATION, - .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, - }; - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = - factory_internal::SetGridwiseGemmInfo(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = 
factory_internal::SetFwdConvABlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetFwdConvBBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); - static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto PIPELINE_VERSION = - factory_internal::SetBlockGemmPipelineVersion(); + static constexpr auto BLOCK_GEMM = factory_internal::SetBlockGemm(); // Check limits for the algorithm parameters. // TODO: Add more limits checks as needed. @@ -520,54 +699,295 @@ struct ConvFactory static_assert(AccessOrderLimits); // The forward convolution kernel class instance. - using Instance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< // - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - A_BLOCK_TRANSFER.lds_padding, - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - B_BLOCK_TRANSFER.lds_padding, - C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, - 
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - PIPELINE_SCHEDULER, - PIPELINE_VERSION>; + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::AComputeType, + typename Types::BComputeType, + IS_DIRECT_LOAD>; +}; + +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. 
+template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseXdlGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); + static_assert(SpecifiesNumGroupsToMerge, + "The convolution algorithm descriptor must specify number of groups to merge."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr 
factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. 
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER, + ALGORITHM.num_groups_to_merge>; +}; + +// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance +// of a grouped forward convolution kernel. 
+template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseWmmaGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = 
factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + factory_internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + 
B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 47bd8327d4..a2ef89da2e 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -138,6 +138,45 @@ enum class BlockGemmPipelineVersion V5 }; +enum struct BlockGemmPipelineScheduler +{ + INTRAWAVE, + INTERWAVE, +}; + +// Enums for the gridwise GEMM pipeline versions. +enum class GridwiseGemmPipelineVersion +{ + V1, + V2, + V3, // Only used in stream-K implementation + V4, + WEIGHT_ONLY +}; + +// Enums for the GEMM specialization. +enum struct GemmSpecialization +{ + // Gemm + Default, + MPadding, + NPadding, + KPadding, + MNPadding, + MKPadding, + NKPadding, + MNKPadding, + // Gemm + Gemm + OPadding, + MOPadding, + NOPadding, + KOPadding, + MNOPadding, + MKOPadding, + NKOPadding, + MNKOPadding +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { @@ -147,4 +186,10 @@ enum class ConvFwdSpecialization FILTER_3x3 }; +enum class LoopScheduler +{ + DEFAULT, + INTERWAVE +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 66ecf32197..26a666a805 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -35,7 +35,9 @@ add_ck_builder_test(test_ckb_get_instance_string # Testing the fwd convolution builder requires kernel compilation. # To enable parallel compilation, the individual tests are split into separate files. 
add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_1d_bf16.cpp + conv/test_ckb_conv_fwd_1d_fp16.cpp + conv/test_ckb_conv_fwd_1d_bf16.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 77ff0fe28f..472c43438d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -21,10 +21,11 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + FwdConvSignature, + FwdThreadBlock, + BlockGemmPipelineVersion::V2, + ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp new file mode 100644 index 0000000000..3f840ba2b0 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -0,0 +1,28 @@ +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +// 1D FP16 (channels-last) with DEFAULT specialization +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 64, + 
.tile_size = {.m = 64, .n = 32, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp new file mode 100644 index 0000000000..1819cca728 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -0,0 +1,28 @@ +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +// 1D I8 (channels-last) with DEFAULT specialization +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, + .data_type = DataType::I8, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; + + run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 5be7d5e604..b9969f7e95 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -20,10 +20,10 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } // 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3 @@ -42,10 +42,10 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = 
{.m = 256, .n = 256, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 4abe3df40d..cd5186cc10 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -19,10 +19,11 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + FwdConvSignature, + FwdThreadBlock, + BlockGemmPipelineVersion::V3, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 5ea804cf8b..584e0ab182 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -19,10 +19,11 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + FwdConvSignature, + FwdThreadBlock, + BlockGemmPipelineVersion::V4, + ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index c729148346..17caf98457 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -20,10 +20,10 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - run_test(); + 
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 832acd7412..ec4649a6ff 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -20,10 +20,11 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + FwdConvSignature, + FwdThreadBlock, + BlockGemmPipelineVersion::V4, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 9d0e107dbc..393ea9206d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -20,10 +20,11 @@ TEST(FwdConvInstances, constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; - run_test(); + run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + FwdConvSignature, + FwdThreadBlock, + BlockGemmPipelineVersion::V1, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); } } // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 2a6ec187dc..9c5ca9b97b 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,8 +28,8 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -// Describe gridwise GEMM parameters. -struct GridwiseGemm +// Describe gridwise XDL GEMM parameters. 
+struct GridwiseXdlGemm { // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! size_t ak1 = 0; @@ -39,7 +39,26 @@ struct GridwiseGemm size_t m_xdl_per_wave = 0; size_t n_xdl_per_wave = 0; }; -static_assert(ckb::GridwiseGemmDescriptor); +static_assert(ckb::GridwiseXdlGemmDescriptor); + +// Describe gridwise WMMA GEMM parameters. +struct GridwiseWmmaGemm +{ + size_t k1 = 0; + size_t m_per_wmma = 0; + size_t n_per_wmma = 0; + size_t m_wmma_per_wave = 0; + size_t n_wmma_per_wave = 0; + GridwiseGemmPipelineVersion pipeline_version; +}; +static_assert(ckb::GridwiseWmmaGemmDescriptor); + +struct BlockGemm +{ + BlockGemmPipelineVersion pipeline_version; + BlockGemmPipelineScheduler scheduler; +}; +static_assert(ckb::BlockGemmDescriptor); // Describe Aand B block transfer thread cluster lengths. struct BlockTransfer @@ -72,8 +91,8 @@ static_assert(LdsTransferDescriptor); struct Epilogue { - size_t m_xdl_per_wave_per_shuffle; - size_t n_xdl_per_wave_per_shuffle; + size_t m_per_wave_per_shuffle; + size_t n_per_wave_per_shuffle; size_t scalar_per_vector; }; static_assert(EpilogueDescriptor); @@ -98,22 +117,101 @@ struct BlockTransferABC AccessOrder src_access_order_b; }; -struct ConvAlgorithm +struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { ThreadBlock thread_block; - GridwiseGemm gridwise_gemm; + GridwiseXdlGemm gridwise_gemm; BlockTransferABC block_transfer; - BlockGemmPipelineVersion pipeline_version; ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + BlockGemm block_gemm; }; -static_assert(ckb::ConvAlgorithmDescriptor); -static_assert(ckb::SpecifiesThreadBlock); -static_assert(ckb::SpecifiesGridwiseGemm); -static_assert(ckb::SpecifiesBlockTransfer); -static_assert(ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder); -static_assert(ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesGemmPipelineVersion); 
-static_assert(ckb::SpecifiesFwdConcSpecialization); +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert( + ckb::SpecifiesGridwiseXdlGemm); +static_assert( + ckb::SpecifiesBlockTransfer); +static_assert( + ckb::SpecifiesLdsTransfer); +static_assert(ckb::SpecifiesThreadClusterAccessOrder< + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); +static_assert( + ckb::SpecifiesSourceAccessOrder); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); +static_assert( + ckb::SpecifiesBlockGemm); +static_assert(ckb::SpecifiesGemmSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); + +struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +{ + ThreadBlock thread_block; + GridwiseXdlGemm gridwise_gemm; + BlockTransferABC block_transfer; + ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + size_t num_gemm_k_prefetch_stages; + size_t num_groups_to_merge; + LoopScheduler loop_scheduler; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert( + ckb::SpecifiesGridwiseXdlGemm); +static_assert( + ckb::SpecifiesBlockTransfer); +static_assert( + ckb::SpecifiesLdsTransfer); +static_assert(ckb::SpecifiesThreadClusterAccessOrder< + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); +static_assert( + ckb::SpecifiesSourceAccessOrder); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); +static_assert( + ckb::SpecifiesNumPrefetchStages); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesLoopScheduler); +static_assert( + ckb::SpecifiesNumGroupsToMerge); + +struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle +{ + ThreadBlock thread_block; + GridwiseWmmaGemm gridwise_gemm; + BlockTransferABC block_transfer; 
+ ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + size_t num_gemm_k_prefetch_stages; + LoopScheduler loop_scheduler; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert(ckb::SpecifiesThreadBlock); +static_assert( + ckb::SpecifiesGridwiseWmmaGemm); +static_assert( + ckb::SpecifiesBlockTransfer); +static_assert(ckb::SpecifiesLdsTransfer); +static_assert(ckb::SpecifiesThreadClusterAccessOrder< + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>); +static_assert( + ckb::SpecifiesSourceAccessOrder); +static_assert( + ckb::SpecifiesFwdConcSpecialization); +static_assert( + ckb::SpecifiesNumPrefetchStages); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesLoopScheduler); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index cd3943d26f..d18a008015 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include @@ -15,14 +18,14 @@ template -constexpr void run_test() +constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() { - constexpr GridwiseGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; + constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4}; constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -40,19 +43,24 @@ constexpr void run_test() .lds_dst_scalar_per_vector = 8, .is_direct_load = false, .lds_padding = false}, - .epilogue_c = {.m_xdl_per_wave_per_shuffle = 1, - .n_xdl_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, .block_transfer_access_order_a = {1, 0, 2}, .block_transfer_access_order_b = {1, 0, 2}, .src_access_order_a = {1, 0, 2}, .src_access_order_b = {1, 0, 2}}; - constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .pipeline_version = FwdPipelineVersion, - .fwd_specialization = FwdConvSpecialization}; + constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, + .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .block_gemm = BlockGemmDesc}; using Builder = ConvBuilder; @@ -88,4 +96,143 @@ constexpr void run_test() EXPECT_NE(invoker_ptr, nullptr); } +template +constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() +{ + 
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .num_groups_to_merge = 2, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + 
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + +template +constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() +{ + constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1, + .pipeline_version = GridwiseGemmPipelineVersion::V1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = 
ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Wmma_CShuffle")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils