From 91f7d4ac750b2e95bf56a586eaf09d94f8f5e178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Thu, 13 Nov 2025 18:47:25 +0200 Subject: [PATCH] [CK_BUILDER] Forward convolution builder improvements (#3179) Proposed changes Improved the forward convolution builder implementation and addressed leftover feedback from PR #3138. Main changes Refactored tests such that they better reflect the builder pattern. The templates and types for the convolution algorithm concepts are created via a factory that facilitates programmatic creation of the device op instances. Moved tests into an anonymous namespace. The convolution factory had a lot of if-else constructs when CK Builder types were converted into CK library types. I initially had trouble using static_assert in the default branch of a switch, as the static_assert was evaluated at compile time even for valid types. However, if we change the static_assert to throw "", it will result in a compile-time error only if the default branch is actually hit. This assumes that the function is consteval. Hence, changed all conversions in the convolution factory to use switch, which is more intuitive. Removed the explicit device op definition from the convolution signature and the corresponding predicate file. The device ops are defined by the corresponding concepts. This allowed removing a lot of boilerplate code from the convolution factory. Added inheritance and convolution algorithm specialization to handle device ops that are specializations of more generic ones. The large tensor support is more naturally expressed by this pattern. Added support for the FP8 data type. * WIP: Builder for expected test results. * Improve ckb fwd conv instance tests. * clang-format * Change if-else statements into switch in conv factory. * Fix clang-formatting. * Removed unnecessary includes. * Added missing copyright. * Remove explicit device op flag from convolution signature. 
* Add missing concept. * Fix build. * clang-format * Add test for building conv fwd FP8 instances. * Add missing header to instance traits. * Clean-up recently added instances. * Introduce inheritance and specialization. * Use builder to build conv algorithm templates and types. * clang-format * Fix conv description tests. --------- Co-authored-by: John Shumway [ROCm/composable_kernel commit: 7d57bc169f8206f06bc516a7f930f388def32347] --- .../builder/conv_algorithm_concepts.hpp | 87 ++-- .../include/ck_tile/builder/conv_factory.hpp | 403 +++++------------- .../builder/conv_signature_concepts.hpp | 18 +- .../builder/conv_signature_predicates.hpp | 190 --------- .../ck_tile/builder/device_op_types.hpp | 22 - ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 1 + .../builder/include/ck_tile/builder/types.hpp | 106 +---- experimental/builder/test/CMakeLists.txt | 10 +- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 43 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 38 +- .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 38 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 73 ++-- .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 90 ++-- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_2d_fp8.cpp | 35 ++ ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 75 ++-- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 42 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 44 +- .../test/impl/conv_algorithm_types.hpp | 337 +++++++++------ .../test/impl/conv_signature_types.hpp | 1 - .../builder/test/test_conv_description.cpp | 2 - .../test/utils/ckb_conv_test_common.hpp | 383 ----------------- .../test/utils/ckb_conv_test_configs.hpp | 184 ++++++++ .../test/utils/ckb_conv_test_utils.hpp | 31 ++ 26 files changed, 946 insertions(+), 1439 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp delete mode 100644 
experimental/builder/include/ck_tile/builder/device_op_types.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp delete mode 100644 experimental/builder/test/utils/ckb_conv_test_common.hpp create mode 100644 experimental/builder/test/utils/ckb_conv_test_configs.hpp create mode 100644 experimental/builder/test/utils/ckb_conv_test_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 6006efe4f8..ea67d5ccc2 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,7 +95,8 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// No requirements yet for a ConvAlgorithm concept. +// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this +// concept. template concept ConvAlgorithmDescriptor = std::is_class_v; @@ -183,6 +184,12 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +template +concept SpecifiesLargeTensorSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ @@ -204,9 +211,9 @@ concept DlThreadClusterDescriptor = requires(T t) { { t.n1_xs } -> std::convertible_to>; }; -// Concept for DL block transfer K0_M0_M1_K1 format +// Concept for DL block transfer template -concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { +concept DlBlockTransferDescriptor = requires(T t) { { t.thread_slice_lengths } -> std::convertible_to>; { t.thread_cluster_lengths } -> std::convertible_to>; { t.thread_cluster_arrange_order } -> std::convertible_to>; @@ 
-216,21 +223,9 @@ concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; -// Concept for DL block transfer K0_N0_N1_K1 format +// Concept for DL epilogue template -concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { - { t.thread_slice_lengths } -> std::convertible_to>; - { t.thread_cluster_lengths } -> std::convertible_to>; - { t.thread_cluster_arrange_order } -> std::convertible_to>; - { t.src_access_order } -> std::convertible_to>; - { t.src_vector_tensor_lengths } -> std::convertible_to>; - { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; - { t.dst_vector_tensor_lengths } -> std::convertible_to>; -}; - -// Concept for DL C thread transfer -template -concept DlCThreadTransferDescriptor = requires(T t) { +concept DlEpilogueDescriptor = requires(T t) { { t.src_dst_access_order } -> std::convertible_to>; { t.src_dst_vector_dim } -> std::convertible_to; { t.dst_scalar_per_vector } -> std::convertible_to; @@ -239,31 +234,63 @@ concept DlCThreadTransferDescriptor = requires(T t) { // Concept to check if algorithm specifies DL thread config template concept SpecifiesDlThreadConfig = requires { - { T::dl_thread_config } -> DlThreadConfigDescriptor; + { T::thread_config } -> DlThreadConfigDescriptor; }; // Concept to check if algorithm specifies DL thread cluster template concept SpecifiesDlThreadCluster = requires { - { T::dl_thread_cluster } -> DlThreadClusterDescriptor; + { T::thread_cluster } -> DlThreadClusterDescriptor; }; -// Concept to check if algorithm specifies DL A block transfer +// Concept to check if algorithm specifies DL block transfer template -concept SpecifiesDlBlockTransferA = requires { - { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; -}; - -// Concept to check if algorithm specifies DL B block transfer -template -concept SpecifiesDlBlockTransferB = requires { - { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +concept 
SpecifiesDlBlockTransfer = requires { + { T::block_transfer_a } -> DlBlockTransferDescriptor; + { T::block_transfer_b } -> DlBlockTransferDescriptor; }; // Concept to check if algorithm specifies DL C thread transfer template -concept SpecifiesDlCThreadTransfer = requires { - { T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor; +concept SpecifiesDlEpilogue = requires { + { T::epilogue_c } -> DlEpilogueDescriptor; }; +/******************************************** */ +/* Concepts for the different device ops */ +/******************************************** */ + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; + +template +concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; + +template +concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + 
SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; + +template +concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle && + SpecifiesLargeTensorSupport; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index e40199987d..5b3fb0e6a8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -256,6 +256,19 @@ struct ConvTensorTypes using EDataType = int8_t; }; +template <> +struct ConvTensorTypes +{ + using ADataType = ck::f8_t; + using AComputeType = ck::f8_t; + using BDataType = ck::f8_t; + using BComputeType = ck::f8_t; + using CShuffleDataType = ck::f8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::f8_t; +}; + template struct ElementwiseOps { @@ -302,49 +315,31 @@ struct BlockGemmSpec }; template -constexpr BlockGemmSpec SetBlockGemm() +consteval BlockGemmSpec SetBlockGemm() { constexpr auto& BG = ALGORITHM.block_gemm; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) + switch(BG.scheduler) { - scheduler = ck::BlockGemmPipelineScheduler::Intrawave; - } - else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) - { - scheduler = ck::BlockGemmPipelineScheduler::Interwave; - } - else - { - static_assert(false, "Unknown PipelineScheduler"); + case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break; + case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break; + case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave."; + default: throw "Unknown PipelineScheduler"; } - if constexpr(BG.pipeline_version == PipelineVersion::V1) + switch(BG.pipeline_version) { 
- version = ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V2) - { - version = ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V3) - { - version = ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V4) - { - version = ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V5) - { - version = ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; + case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; + case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; + case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; + case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; + default: throw "Unknown PipelineVersion"; } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -453,18 +448,13 @@ template consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - - if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) + using ck_loop_sched = ck::LoopScheduler; + switch(loop_scheduler) { - return ck::LoopScheduler::Default; - } - else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) - { - return ck::LoopScheduler::Interwave; - } - else - { - static_assert(false, "Unknown PipelineScheduler"); + case PipelineScheduler::DEFAULT: return ck_loop_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave; + case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE."; + default: throw "Unknown PipelineScheduler"; } } @@ -472,29 
+462,16 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == PipelineVersion::V1) + using ck_pipeline = ck::PipelineVersion; + switch(pipeline_version) { - return ck::PipelineVersion::v1; - } - else if constexpr(pipeline_version == PipelineVersion::V2) - { - return ck::PipelineVersion::v2; - } - else if constexpr(pipeline_version == PipelineVersion::V3) - { - static_assert(false, "V3 is used only for stream-K."); - } - else if constexpr(pipeline_version == PipelineVersion::V4) - { - return ck::PipelineVersion::v4; - } - else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) - { - return ck::PipelineVersion::weight_only; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; + default: throw "Unknown GridwiseGemmPipelineVersion"; } } @@ -502,74 +479,27 @@ template consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() { constexpr auto gemm_spec = ALGORITHM.gemm_specialization; + using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; - if constexpr(gemm_spec == GemmSpecialization::Default) + switch(gemm_spec) { - return ck::tensor_operation::device::GemmSpecialization::Default; - } - else if constexpr(gemm_spec == GemmSpecialization::MPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NPadding; - } - else if 
constexpr(gemm_spec == GemmSpecialization::KPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::OPadding) - { - return ck::tensor_operation::device::GemmSpecialization::OPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::KOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - } - else - { - static_assert(false, "Unknown GemmSpecialization"); + case GemmSpecialization::Default: return ck_gemm_spec::Default; + case 
GemmSpecialization::MPadding: return ck_gemm_spec::MPadding; + case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding; + case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding; + case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding; + case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding; + case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding; + case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding; + case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding; + case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding; + case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding; + case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding; + case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding; + case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding; + case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding; + case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding; + default: throw "Unknown GemmSpecialization"; } } @@ -577,30 +507,15 @@ template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - - if constexpr(version == PipelineVersion::V1) + using ck_pipeline = ck::BlockGemmPipelineVersion; + switch(version) { - return ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(version == PipelineVersion::V2) - { - return ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(version == PipelineVersion::V3) - { - return ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(version == PipelineVersion::V4) - { - return ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(version == PipelineVersion::V5) - { - return ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; 
+ case PipelineVersion::V3: return ck_pipeline::v3; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: return ck_pipeline::v5; + default: throw "Unknown block GEMM PipelineVersion"; } } @@ -608,26 +523,14 @@ template consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() { constexpr auto specialization = ALGORITHM.fwd_specialization; - - if constexpr(specialization == ConvFwdSpecialization::DEFAULT) + using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; + switch(specialization) { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3; - } - else - { - static_assert(false, "Unknown ConvFwdSpecialization"); + case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; } } @@ -639,7 +542,12 @@ namespace ck_tile::builder { template -struct ConvFactory; +struct ConvFactory +{ + // This will trigger if a specialization for the given convolution direction is not found. + // We should always catch this in an earlier validation check. 
+ static_assert(false, "Unsupported device operation."); +}; // Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance // of a grouped forward convolution kernel. @@ -647,7 +555,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -658,26 +566,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesBlockGemm, - "The convolution algorithm descriptor must specify block gemm pipeline."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, "A and B block transfers must both be direct load or not."); @@ -769,7 +657,7 @@ template requires ConvDirectionIsForward && - 
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -780,31 +668,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static_assert(SpecifiesNumGroupsToMerge, - "The convolution algorithm descriptor must specify number of groups to merge."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -891,7 +754,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle 
struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -902,27 +765,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseWmmaGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -1008,7 +850,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -1019,24 +861,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - 
static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesDlThreadConfig, - "DL algorithm must specify thread config."); - static_assert(SpecifiesDlThreadCluster, - "DL algorithm must specify thread cluster."); - static_assert(SpecifiesDlBlockTransferA, - "DL algorithm must specify A block transfer."); - static_assert(SpecifiesDlBlockTransferB, - "DL algorithm must specify B block transfer."); - static_assert(SpecifiesDlCThreadTransfer, - "DL algorithm must specify C thread transfer."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -1045,7 +869,7 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); // DL-specific parameters from algorithm descriptor - static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; @@ -1053,12 +877,12 @@ struct ConvFactory static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; // Thread cluster from descriptor - static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; using M1N1ThreadClusterM1Xs = to_sequence_v; using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; + static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using 
ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -1074,7 +898,7 @@ struct ConvFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -1090,7 +914,7 @@ struct ConvFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -1148,7 +972,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -1159,45 +983,24 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access 
order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); + static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); + factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); + factory_internal::SetGemmSpecialization(); static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvABlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); + factory_internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. 
static_assert(InputVectorTransferLimits); @@ -1227,7 +1030,7 @@ struct ConvFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + BASE_ALGORITHM.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 742dfbb89c..983273b439 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -21,7 +21,6 @@ #include #include "ck_tile/builder/types.hpp" -#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -41,9 +40,6 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); -template -concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; - template concept ConvLayout = std::same_as, GroupConvLayout>; @@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) { { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; - { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. @@ -63,7 +58,18 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; - requires IsValidConvDeviceOp; }; +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. 
+template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp deleted file mode 100644 index 3869c7b538..0000000000 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/builder/types.hpp" - -namespace ck_tile::builder { - -/********************************************** - * Conv Direction Predicates - **********************************************/ - -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - -/********************************************** - * Conv Fwd Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); - -// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); - -// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); - -// Generic predicate to check if signature uses any forward convolution device operation. -template -concept ConvDeviceOpIsForward = - ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; - -/********************************************** - * Conv Bwd Weight Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdWeight operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); - -// Predicate for DeviceGroupedConvBwdWeight_Dl operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); - -// Generic predicate to check if signature uses any backward weight convolution device operation. -template -concept ConvDeviceOpIsBackwardWeight = - ConvDeviceOpIs_DeviceGroupedConvBwdWeight || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; - -/********************************************** - * Conv Bwd Data Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); - -// Predicate for DeviceGroupedConvBwdDataMultipleD operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. 
-template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); - -// Generic predicate to check if signature uses any backward data convolution device operation. -template -concept ConvDeviceOpIsBackwardData = - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; - -/********************************************** - * Generic Device Op Predicates - **********************************************/ - -// Generic predicate to check if signature uses any device operation. -template -concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || - ConvDeviceOpIsBackwardWeight; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp deleted file mode 100644 index 0e779fdf4e..0000000000 --- a/experimental/builder/include/ck_tile/builder/device_op_types.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -namespace ck_tile::builder { - -// Enumeration for CK Device Operation types. -// This allows the builder to select which device operation template to instantiate -// based on the user's requirements. 
-enum class DeviceOpType -{ - // Forward Convolution - Non-grouped - CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) - - // Forward Convolution - Grouped - GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle - GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: - // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -}; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8756825c3f..5c267f2552 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index a58c994288..fa2b99ef56 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -74,52 +74,6 @@ enum class ConvDirection BACKWARD_WEIGHT }; -// Forward convolution device operations. -enum class FwdGroupConvDeviceOperation -{ - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -}; - -// Backward data convolution device operations. 
-enum class BwdDataGroupConvDeviceOperation -{ - DeviceGroupedConvBwdDataMultipleD, - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 -}; - -// Backward weight convolution device operations. -enum class BwdWeightGroupConvDeviceOperation -{ - DeviceGroupedConvBwdWeight, - DeviceGroupedConvBwdWeight_Dl, - DeviceGroupedConvBwdWeight_Xdl_CShuffle, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, - DeviceGroupedConvBwdWeightMultipleD, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, -}; - -// Structural type for device operation -struct GroupConvDeviceOp -{ - union - { - FwdGroupConvDeviceOperation _fwd; - BwdDataGroupConvDeviceOperation _bwd_data; - BwdWeightGroupConvDeviceOperation _bwd_weight; - }; - - constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} - constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} - constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} -}; - // Fused element-wise operations. 
enum class ElementwiseOperation { @@ -219,6 +173,11 @@ enum class PipelineScheduler INTERWAVE }; +enum class ConvAlgorithmSpecialization +{ + LARGE_TENSOR +}; + // ostream operator overloads for enum classes inline std::ostream& operator<<(std::ostream& os, DataType dt) { @@ -286,61 +245,6 @@ inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) } } -inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) -{ - using enum FwdGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: - return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; - case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: - return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; - case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: - return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; - case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: - return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; - case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: - return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) -{ - using enum BwdDataGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; - case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: - return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; - case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: - return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) -{ - using enum BwdWeightGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; - case DeviceGroupedConvBwdWeight_Dl: return os << 
"DeviceGroupedConvBwdWeight_Dl"; - case DeviceGroupedConvBwdWeight_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; - case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: - return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; - case DeviceGroupedConvBwdWeight_Wmma_CShuffle: - return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; - case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; - case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; - case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; - default: return os << "Unknown"; - } -} - inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) { using enum ElementwiseOperation; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 43c4fd4857..5044d223de 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_2d_fp8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -47,7 +48,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp) + conv/test_ckb_conv_fwd_3d_fp32.cpp + ) function(add_ck_factory_test test_name) add_ck_builder_test(${test_name} ${ARGN}) @@ -66,10 +68,10 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam 
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) -add_ck_builder_test(test_conv_traits +add_ck_builder_test(test_ckb_conv_traits conv/test_conv_traits.cpp) -add_ck_builder_test(test_conv_description +add_ck_builder_test(test_ckb_conv_description test_conv_description.cpp) # Function to add all test_ckb targets to a list diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 123034eb77..bb0c767bbd 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -1,34 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE // elementwise op TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, + .data_type = DataType::BF16, + .elementwise_operation 
= ElementwiseOperation::SCALE}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v2_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V2, - ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Stride1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v2"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index a83ca84297..d391c1a74d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -1,31 +1,35 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D FP16 (channels-last) with DEFAULT specialization TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 3ceac2a047..7206c768d8 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -1,31 +1,35 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D I8 (channels-last) with and DEFAULT specialization TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, - .data_type = DataType::I8, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, + .data_type = DataType::I8, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} + .with_thread_block(FwdThreadBlock_128_64x64x64) + .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x32x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); - run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp 
b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 240746f546..d0bc1d7a6d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -1,54 +1,63 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v1_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + 
run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Default", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v1"}); } // 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v5_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "Filter3x3", + "BlkGemmPipelineVersion: v5"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 12730bab19..0a337f3a7b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -1,69 +1,59 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} + .with_thread_block(FwdThreadBlock_256_128x128x16) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) + .with_dl_thread_cluster(DlThreadCluster_8x2) + .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) + .with_dl_epilogue(DlEpilogueC); - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); -} - -TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) -{ - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = 
DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; - - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; - - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"}); } TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} + .with_thread_block(FwdThreadBlock_256_128x128x16) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) + .with_dl_thread_cluster(DlThreadCluster_8x2) + .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) + .with_dl_epilogue(DlEpilogueC); - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 
16", "Filter1x1Pad0"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 6366016707..798c6cfa2d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v3_intrawave); - 
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V3, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v3"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 7b303a7bde..a8313ff510 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, - .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, + .data_type = DataType::FP32, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + 
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v4_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V4, - ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 128, 128, 32", + "Filter1x1Stride1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v4"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp new file mode 100644 index 0000000000..39319bb79e --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -0,0 +1,35 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP8, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1_fp8) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"}); +} + +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 0216c5907d..6c43678bf1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -1,53 +1,64 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ + .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; - run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::DEFAULT>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", + "256, 256, 128, 32", + "Default"}); } TEST( FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) { - 
constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 128, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ + .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_128_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; - run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", + "128, 128, 128, 32", + "Filter1x1Pad0"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index b40dd0b0d7..2392d1efff 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v3_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Default", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v3"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp 
b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index e0dad4e1a1..52e153098e 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -1,33 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v4_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V4, - 
ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 128, 128, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v4"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 43ffb3f89a..5d1656924c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -1,33 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, - .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, + .data_type = DataType::FP32, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + 
.with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v1_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V1, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v1"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 88c5b5787a..fd756cf06e 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -117,103 +117,6 @@ struct BlockTransferABC AccessOrder src_access_order_b; }; -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -{ - ThreadBlock thread_block; - GridwiseXdlGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - BlockGemm block_gemm; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesBlockGemm); 
-static_assert(ckb::SpecifiesGemmSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); - -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle -{ - ThreadBlock thread_block; - GridwiseXdlGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - size_t num_gemm_k_prefetch_stages; - size_t num_groups_to_merge; - PipelineScheduler loop_scheduler; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesLoopScheduler); -static_assert( - ckb::SpecifiesNumGroupsToMerge); - -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle -{ - ThreadBlock thread_block; - GridwiseWmmaGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - size_t num_gemm_k_prefetch_stages; - PipelineScheduler loop_scheduler; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert(ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseWmmaGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert(ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert( - ckb::SpecifiesFwdConcSpecialization); 
-static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesLoopScheduler); - // DL-specific descriptors struct DlThreadConfig { @@ -227,12 +130,12 @@ static_assert(ckb::DlThreadConfigDescriptor); struct DlThreadCluster { - std::array m1_xs; // e.g., {8, 2} - std::array n1_xs; // e.g., {8, 2} + std::array m1_xs; + std::array n1_xs; }; static_assert(ckb::DlThreadClusterDescriptor); -struct DlBlockTransferK0M0M1K1 +struct DlBlockTransfer { std::array thread_slice_lengths; std::array thread_cluster_lengths; @@ -242,56 +145,212 @@ struct DlBlockTransferK0M0M1K1 std::array src_vector_tensor_contiguous_dim_order; std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); +static_assert(ckb::DlBlockTransferDescriptor); -struct DlBlockTransferK0N0N1K1 -{ - std::array thread_slice_lengths; - std::array thread_cluster_lengths; - std::array thread_cluster_arrange_order; - std::array src_access_order; - std::array src_vector_tensor_lengths; - std::array src_vector_tensor_contiguous_dim_order; - std::array dst_vector_tensor_lengths; -}; -static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); - -struct DlCThreadTransfer +struct DlEpilogue { std::array src_dst_access_order; size_t src_dst_vector_dim; size_t dst_scalar_per_vector; }; -static_assert(ckb::DlCThreadTransferDescriptor); +static_assert(ckb::DlEpilogueDescriptor); -struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +// Factory types + +struct ThreadBlock_ { ThreadBlock thread_block; +}; + +struct XdlGemm_ +{ + GridwiseXdlGemm gridwise_gemm; +}; + +struct WmmaGemm_ +{ + GridwiseWmmaGemm gridwise_gemm; +}; + +struct BlockTransfer_ +{ + BlockTransferABC block_transfer; +}; + +struct ConvSpecialization_ +{ ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; - DlThreadConfig dl_thread_config; - DlThreadCluster dl_thread_cluster; - DlBlockTransferK0M0M1K1 
dl_block_transfer_a; - DlBlockTransferK0N0N1K1 dl_block_transfer_b; - DlCThreadTransfer dl_c_thread_transfer; }; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesDlThreadConfig); -static_assert( - ckb::SpecifiesDlThreadCluster); -static_assert( - ckb::SpecifiesDlBlockTransferA); -static_assert( - ckb::SpecifiesDlBlockTransferB); -static_assert( - ckb::SpecifiesDlCThreadTransfer); + +struct Prefetch_ +{ + size_t num_gemm_k_prefetch_stages; + size_t num_groups_to_merge; + PipelineScheduler loop_scheduler; +}; + +struct BlockGemm_ +{ + BlockGemm block_gemm; +}; + +struct DlThreadConfig_ +{ + DlThreadConfig thread_config; +}; + +struct DlThreadCluster_ +{ + DlThreadCluster thread_cluster; +}; + +struct DlBlockTransfer_ +{ + DlBlockTransfer block_transfer_a; + DlBlockTransfer block_transfer_b; +}; + +struct DlEpilogue_ +{ + DlEpilogue epilogue_c; +}; + +// Specialization wrapper for large tensor support +template +struct LargeTensorWrapper +{ + BaseAlgorithm base_algorithm; + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::LARGE_TENSOR; +}; + +// Factory + +template +struct ConvAlgorithmTemplate : Components... 
+{ + + template + constexpr auto with_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_gemm_config(const GemmConfig& gemm) const + { + auto result = *this; + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + else if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + return result; + } + + template + constexpr auto with_block_transfer(const BT& bt) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_transfer = bt; + return result; + } + + constexpr auto with_specializations(ConvFwdSpecialization fwd_spec, + GemmSpecialization gemm_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.fwd_specialization = fwd_spec; + result.gemm_specialization = gemm_spec; + return result; + } + + constexpr auto with_prefetch_config(size_t k_prefetch_stages, + size_t groups_to_merge, + PipelineScheduler scheduler) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_gemm_k_prefetch_stages = k_prefetch_stages; + result.num_groups_to_merge = groups_to_merge; + result.loop_scheduler = scheduler; + return result; + } + + template + constexpr auto with_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_dl_thread_config(const TC& tc) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_config = tc; + return result; + } + + template + constexpr auto with_dl_thread_cluster(const TCl& tcl) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_cluster = tcl; + return result; + } + + template + constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + 
result.block_transfer_a = bta; + result.block_transfer_b = btb; + return result; + } + + constexpr auto with_dl_epilogue(const DlEpilogue& epi) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.epilogue_c = epi; + return result; + } +}; + +// Algorithm types + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvAlgorithmTemplate; +using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + LargeTensorWrapper; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 5e6684c4cd..71f16aefbe 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -17,7 +17,6 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; - GroupConvDeviceOp device_operation; }; static_assert(ConvSignatureDescriptor); diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 97af4af795..733359d491 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -25,8 +25,6 @@ struct ConvSignature ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; ckb::DataType data_type = ckb::DataType::FP16; ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; - ckb::GroupConvDeviceOp device_operation = - ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; }; 
static_assert(ckb::ConvSignatureDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp deleted file mode 100644 index 14fae566f6..0000000000 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "impl/conv_algorithm_types.hpp" -#include "impl/conv_signature_types.hpp" -#include "ck_tile/builder/conv_builder.hpp" - -namespace ck_tile::builder::test_utils { - -using namespace ck_tile::builder; -using namespace test; - -// Common test implementation -template -constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() -{ - constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = PipelineScheduler::INTRAWAVE}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 
FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .block_gemm = BlockGemmDesc}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); - - // Verify pipeline version is correct - if(FwdPipelineVersion == PipelineVersion::V1) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V3) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V4) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V5) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() -{ - 
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .num_groups_to_merge = 2, - .loop_scheduler = PipelineScheduler::DEFAULT}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - 
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() -{ - constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = PipelineVersion::V1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = PipelineScheduler::DEFAULT}; - - using Builder = 
ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Wmma_CShuffle")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() -{ - // DL thread configuration - constexpr DlThreadConfig DlThreadCfg{ - .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; - - // DL thread cluster - constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; - - // DL A block transfer - K0_M0_M1_K1 format - constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ - .thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; - - // DL B block transfer - K0_N0_N1_K1 format - constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ - .thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - 
.thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; - - // DL C thread transfer - constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .dl_thread_config = DlThreadCfg, - .dl_thread_cluster = DlCluster, - .dl_block_transfer_a = DlBlockTransferA, - .dl_block_transfer_b = DlBlockTransferB, - .dl_c_thread_transfer = DlCTransfer}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -// Note: Large_Tensor has identical 
parameters to regular XDL CShuffle -template -constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() -{ - constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - // Large_Tensor uses the same descriptor as regular XDL CShuffle - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .num_groups_to_merge = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE( - kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); - - // Verify specialization is correct - 
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -} // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp new file mode 100644 index 0000000000..7f2acce9c8 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -0,0 +1,184 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + +constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + +constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + +constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + +constexpr BlockTransferABC FwdBlockTransfer_4x64x1{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = 
{.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ + .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ + .block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + 
.lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; + +constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1, + .pipeline_version = PipelineVersion::V1}; + +constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr BlockGemm BlockGemmDesc_v1_intrawave = 
{.pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp new file mode 100644 index 0000000000..6f9442d067 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -0,0 +1,31 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck_tile::builder::test_utils { +using namespace test; + +// Common test implementation +template +constexpr void run_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); + + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + +} // namespace ck_tile::builder::test_utils