From 0997e2eb6d12c4697724d83d4b4945017e85b94a Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 13 Nov 2025 17:13:19 +0000 Subject: [PATCH] Merge commit '7d57bc169f8206f06bc516a7f930f388def32347' into develop --- .../builder/conv_algorithm_concepts.hpp | 87 ++-- .../include/ck_tile/builder/conv_factory.hpp | 403 +++++------------- .../builder/conv_signature_concepts.hpp | 18 +- .../builder/conv_signature_predicates.hpp | 190 --------- .../ck_tile/builder/device_op_types.hpp | 22 - ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 1 + .../builder/include/ck_tile/builder/types.hpp | 106 +---- experimental/builder/test/CMakeLists.txt | 10 +- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 43 +- .../test/conv/test_ckb_conv_fwd_1d_fp16.cpp | 38 +- .../test/conv/test_ckb_conv_fwd_1d_i8.cpp | 38 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 73 ++-- .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 90 ++-- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_2d_fp8.cpp | 35 ++ ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 75 ++-- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 42 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 44 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 44 +- .../test/impl/conv_algorithm_types.hpp | 337 +++++++++------ .../test/impl/conv_signature_types.hpp | 1 - .../builder/test/test_conv_description.cpp | 2 - .../test/utils/ckb_conv_test_common.hpp | 383 ----------------- .../test/utils/ckb_conv_test_configs.hpp | 184 ++++++++ .../test/utils/ckb_conv_test_utils.hpp | 31 ++ 26 files changed, 946 insertions(+), 1439 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp delete mode 100644 experimental/builder/include/ck_tile/builder/device_op_types.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp delete mode 100644 experimental/builder/test/utils/ckb_conv_test_common.hpp create mode 100644 experimental/builder/test/utils/ckb_conv_test_configs.hpp create mode 100644 experimental/builder/test/utils/ckb_conv_test_utils.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 6006efe4f8..ea67d5ccc2 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,7 +95,8 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// No requirements yet for a ConvAlgorithm concept. +// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this +// concept. template concept ConvAlgorithmDescriptor = std::is_class_v; @@ -183,6 +184,12 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +template +concept SpecifiesLargeTensorSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ @@ -204,9 +211,9 @@ concept DlThreadClusterDescriptor = requires(T t) { { t.n1_xs } -> std::convertible_to>; }; -// Concept for DL block transfer K0_M0_M1_K1 format +// Concept for DL block transfer template -concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { +concept DlBlockTransferDescriptor = requires(T t) { { t.thread_slice_lengths } -> std::convertible_to>; { t.thread_cluster_lengths } -> std::convertible_to>; { t.thread_cluster_arrange_order } -> std::convertible_to>; @@ -216,21 +223,9 @@ concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; -// Concept for DL block transfer K0_N0_N1_K1 format +// Concept for DL epilogue template -concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { - { t.thread_slice_lengths } -> std::convertible_to>; - { t.thread_cluster_lengths } -> std::convertible_to>; - { t.thread_cluster_arrange_order } -> std::convertible_to>; - { t.src_access_order } -> std::convertible_to>; - { t.src_vector_tensor_lengths } -> std::convertible_to>; - { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; - { t.dst_vector_tensor_lengths } -> std::convertible_to>; -}; - -// Concept for DL C thread transfer -template -concept DlCThreadTransferDescriptor = requires(T t) { +concept DlEpilogueDescriptor = requires(T t) { { t.src_dst_access_order } -> std::convertible_to>; { t.src_dst_vector_dim } -> std::convertible_to; { t.dst_scalar_per_vector } -> std::convertible_to; @@ -239,31 +234,63 @@ concept DlCThreadTransferDescriptor = requires(T t) { // Concept to check if algorithm specifies DL thread config template concept SpecifiesDlThreadConfig = requires { - { T::dl_thread_config } -> DlThreadConfigDescriptor; + { T::thread_config } -> DlThreadConfigDescriptor; }; // Concept to check if algorithm specifies DL thread cluster template concept SpecifiesDlThreadCluster = requires { - { T::dl_thread_cluster } -> DlThreadClusterDescriptor; + { T::thread_cluster } -> DlThreadClusterDescriptor; }; -// Concept to check if algorithm specifies DL A block transfer +// Concept to check if algorithm specifies DL block transfer template -concept SpecifiesDlBlockTransferA = requires { - { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; -}; - -// Concept to check if algorithm specifies DL B block transfer -template -concept SpecifiesDlBlockTransferB = requires { - { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +concept SpecifiesDlBlockTransfer = requires { + { T::block_transfer_a } -> DlBlockTransferDescriptor; + { T::block_transfer_b } -> DlBlockTransferDescriptor; }; // Concept to check if algorithm specifies DL C thread transfer template -concept SpecifiesDlCThreadTransfer = requires { - { T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor; +concept SpecifiesDlEpilogue = requires { + { T::epilogue_c } -> DlEpilogueDescriptor; }; +/******************************************** */ +/* Concepts for the different device ops */ +/******************************************** */ + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; + +template +concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; + +template +concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; + +template +concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConcSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; + +template +concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle && + SpecifiesLargeTensorSupport; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index e40199987d..5b3fb0e6a8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -256,6 +256,19 @@ struct ConvTensorTypes using EDataType = int8_t; }; +template <> +struct ConvTensorTypes +{ + using ADataType = ck::f8_t; + using AComputeType = ck::f8_t; + using BDataType = ck::f8_t; + using BComputeType = ck::f8_t; + using CShuffleDataType = ck::f8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::f8_t; +}; + template struct ElementwiseOps { @@ -302,49 +315,31 @@ struct BlockGemmSpec }; template -constexpr BlockGemmSpec SetBlockGemm() +consteval BlockGemmSpec SetBlockGemm() { constexpr auto& BG = ALGORITHM.block_gemm; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) + switch(BG.scheduler) { - scheduler = ck::BlockGemmPipelineScheduler::Intrawave; - } - else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) - { - scheduler = ck::BlockGemmPipelineScheduler::Interwave; - } - else - { - static_assert(false, "Unknown PipelineScheduler"); + case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break; + case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break; + case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave."; + default: throw "Unknown PipelineScheduler"; } - if constexpr(BG.pipeline_version == PipelineVersion::V1) + switch(BG.pipeline_version) { - version = ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V2) - { - version = ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V3) - { - version = ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V4) - { - version = ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(BG.pipeline_version == PipelineVersion::V5) - { - version = ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; + case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; + case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; + case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; + case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; + default: throw "Unknown PipelineVersion"; } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -453,18 +448,13 @@ template consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - - if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) + using ck_loop_sched = ck::LoopScheduler; + switch(loop_scheduler) { - return ck::LoopScheduler::Default; - } - else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) - { - return ck::LoopScheduler::Interwave; - } - else - { - static_assert(false, "Unknown PipelineScheduler"); + case PipelineScheduler::DEFAULT: return ck_loop_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave; + case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE."; + default: throw "Unknown PipelineScheduler"; } } @@ -472,29 +462,16 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == PipelineVersion::V1) + using ck_pipeline = ck::PipelineVersion; + switch(pipeline_version) { - return ck::PipelineVersion::v1; - } - else if constexpr(pipeline_version == PipelineVersion::V2) - { - return ck::PipelineVersion::v2; - } - else if constexpr(pipeline_version == PipelineVersion::V3) - { - static_assert(false, "V3 is used only for stream-K."); - } - else if constexpr(pipeline_version == PipelineVersion::V4) - { - return ck::PipelineVersion::v4; - } - else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) - { - return ck::PipelineVersion::weight_only; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; + default: throw "Unknown GridwiseGemmPipelineVersion"; } } @@ -502,74 +479,27 @@ template consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() { constexpr auto gemm_spec = ALGORITHM.gemm_specialization; + using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; - if constexpr(gemm_spec == GemmSpecialization::Default) + switch(gemm_spec) { - return ck::tensor_operation::device::GemmSpecialization::Default; - } - else if constexpr(gemm_spec == GemmSpecialization::MPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::KPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::OPadding) - { - return ck::tensor_operation::device::GemmSpecialization::OPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::KOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - } - else - { - static_assert(false, "Unknown GemmSpecialization"); + case GemmSpecialization::Default: return ck_gemm_spec::Default; + case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding; + case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding; + case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding; + case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding; + case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding; + case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding; + case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding; + case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding; + case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding; + case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding; + case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding; + case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding; + case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding; + case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding; + case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding; + default: throw "Unknown GemmSpecialization"; } } @@ -577,30 +507,15 @@ template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - - if constexpr(version == PipelineVersion::V1) + using ck_pipeline = ck::BlockGemmPipelineVersion; + switch(version) { - return ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(version == PipelineVersion::V2) - { - return ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(version == PipelineVersion::V3) - { - return ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(version == PipelineVersion::V4) - { - return ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(version == PipelineVersion::V5) - { - return ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown PipelineVersion"); + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: return ck_pipeline::v3; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: return ck_pipeline::v5; + default: throw "Unknown block GEMM PipelineVersion"; } } @@ -608,26 +523,14 @@ template consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() { constexpr auto specialization = ALGORITHM.fwd_specialization; - - if constexpr(specialization == ConvFwdSpecialization::DEFAULT) + using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; + switch(specialization) { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3; - } - else - { - static_assert(false, "Unknown ConvFwdSpecialization"); + case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; } } @@ -639,7 +542,12 @@ namespace ck_tile::builder { template -struct ConvFactory; +struct ConvFactory +{ + // This will trigger if a specialization for the given convolution direction is not found. + // We should always catch this in an earlier validation check. + static_assert(false, "Unsupported device operation."); +}; // Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance // of a grouped forward convolution kernel. @@ -647,7 +555,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -658,26 +566,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesBlockGemm, - "The convolution algorithm descriptor must specify block gemm pipeline."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load == ALGORITHM.block_transfer.lds_transfer_b.is_direct_load, "A and B block transfers must both be direct load or not."); @@ -769,7 +657,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -780,31 +668,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static_assert(SpecifiesNumGroupsToMerge, - "The convolution algorithm descriptor must specify number of groups to merge."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -891,7 +754,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -902,27 +765,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseWmmaGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -1008,7 +850,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -1019,24 +861,6 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesDlThreadConfig, - "DL algorithm must specify thread config."); - static_assert(SpecifiesDlThreadCluster, - "DL algorithm must specify thread cluster."); - static_assert(SpecifiesDlBlockTransferA, - "DL algorithm must specify A block transfer."); - static_assert(SpecifiesDlBlockTransferB, - "DL algorithm must specify B block transfer."); - static_assert(SpecifiesDlCThreadTransfer, - "DL algorithm must specify C thread transfer."); - static constexpr auto FWD_CONV_SPECIALIZATION = factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = @@ -1045,7 +869,7 @@ struct ConvFactory static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); // DL-specific parameters from algorithm descriptor - static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; @@ -1053,12 +877,12 @@ struct ConvFactory static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; // Thread cluster from descriptor - static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; using M1N1ThreadClusterM1Xs = to_sequence_v; using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; + static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -1074,7 +898,7 @@ struct ConvFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -1090,7 +914,7 @@ struct ConvFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -1148,7 +972,7 @@ template requires ConvDirectionIsForward && - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -1159,45 +983,24 @@ struct ConvFactory using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(SpecifiesThreadBlock, - "The convolution algorithm descriptor must specify thread block info."); - static_assert(SpecifiesGridwiseXdlGemm, - "The convolution algorithm descriptor must specify gridwise GEMM info."); - static_assert(SpecifiesBlockTransfer, - "The convolution algorithm descriptor must specify block transfer info."); - static_assert(SpecifiesLdsTransfer, - "The convolution algorithm descriptor must specify LDS transfer info."); - static_assert( - SpecifiesThreadClusterAccessOrder, - "The convolution algorithm descriptor must specify thread cluster access order info."); - static_assert(SpecifiesSourceAccessOrder, - "The convolution algorithm descriptor must specify source access order info."); - static_assert(SpecifiesFwdConcSpecialization, - "The convolution algorithm descriptor must specify forward convolution " - "specialization."); - static_assert(SpecifiesGemmSpecialization, - "The convolution algorithm descriptor must specify gemm specialization."); - static_assert(SpecifiesNumPrefetchStages, - "The convolution algorithm descriptor must specify number of prefetch stages."); - static_assert(SpecifiesLoopScheduler, - "The convolution algorithm descriptor must specify loop scheduler."); + static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); + factory_internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); + factory_internal::SetGemmSpecialization(); static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvABlockTransfer(); + factory_internal::SetFwdConvABlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBBlockTransfer(); + factory_internal::SetFwdConvBBlockTransfer(); static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); + factory_internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. static_assert(InputVectorTransferLimits); @@ -1227,7 +1030,7 @@ struct ConvFactory typename Ops::CDEElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, + BASE_ALGORITHM.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 742dfbb89c..983273b439 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -21,7 +21,6 @@ #include #include "ck_tile/builder/types.hpp" -#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -41,9 +40,6 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); -template -concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; - template concept ConvLayout = std::same_as, GroupConvLayout>; @@ -55,7 +51,6 @@ concept ConvSignatureDescriptor = requires(T t) { { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; - { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. @@ -63,7 +58,18 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; - requires IsValidConvDeviceOp; }; +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp deleted file mode 100644 index 3869c7b538..0000000000 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/builder/types.hpp" - -namespace ck_tile::builder { - -/********************************************** - * Conv Direction Predicates - **********************************************/ - -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - -/********************************************** - * Conv Fwd Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); - -// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); - -// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - ConvDirectionIsForward && - (Sig.device_operation._fwd == - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); - -// Generic predicate to check if signature uses any forward convolution device operation. -template -concept ConvDeviceOpIsForward = - ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || - ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; - -/********************************************** - * Conv Bwd Weight Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdWeight operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); - -// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); - -// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); - -// Predicate for DeviceGroupedConvBwdWeight_Dl operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = - ConvDirectionIsBackwardWeight && - (Sig.device_operation._bwd_weight == - BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); - -// Generic predicate to check if signature uses any backward weight convolution device operation. -template -concept ConvDeviceOpIsBackwardWeight = - ConvDeviceOpIs_DeviceGroupedConvBwdWeight || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || - ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; - -/********************************************** - * Conv Bwd Data Device Op Predicates - **********************************************/ - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); - -// Predicate for DeviceGroupedConvBwdDataMultipleD operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); - -// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. -template -concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = - ConvDirectionIsBackwardData && - (Sig.device_operation._bwd_data == - BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); - -// Generic predicate to check if signature uses any backward data convolution device operation. -template -concept ConvDeviceOpIsBackwardData = - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || - ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; - -/********************************************** - * Generic Device Op Predicates - **********************************************/ - -// Generic predicate to check if signature uses any device operation. -template -concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || - ConvDeviceOpIsBackwardWeight; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp deleted file mode 100644 index 0e779fdf4e..0000000000 --- a/experimental/builder/include/ck_tile/builder/device_op_types.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -namespace ck_tile::builder { - -// Enumeration for CK Device Operation types. -// This allows the builder to select which device operation template to instantiate -// based on the user's requirements. -enum class DeviceOpType -{ - // Forward Convolution - Non-grouped - CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) - - // Forward Convolution - Grouped - GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle - GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: - // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -}; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8756825c3f..5c267f2552 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index a58c994288..fa2b99ef56 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -74,52 +74,6 @@ enum class ConvDirection BACKWARD_WEIGHT }; -// Forward convolution device operations. -enum class FwdGroupConvDeviceOperation -{ - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -}; - -// Backward data convolution device operations. -enum class BwdDataGroupConvDeviceOperation -{ - DeviceGroupedConvBwdDataMultipleD, - DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 -}; - -// Backward weight convolution device operations. -enum class BwdWeightGroupConvDeviceOperation -{ - DeviceGroupedConvBwdWeight, - DeviceGroupedConvBwdWeight_Dl, - DeviceGroupedConvBwdWeight_Xdl_CShuffle, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, - DeviceGroupedConvBwdWeight_Wmma_CShuffle, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, - DeviceGroupedConvBwdWeightMultipleD, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, -}; - -// Structural type for device operation -struct GroupConvDeviceOp -{ - union - { - FwdGroupConvDeviceOperation _fwd; - BwdDataGroupConvDeviceOperation _bwd_data; - BwdWeightGroupConvDeviceOperation _bwd_weight; - }; - - constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} - constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} - constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} -}; - // Fused element-wise operations. enum class ElementwiseOperation { @@ -219,6 +173,11 @@ enum class PipelineScheduler INTERWAVE }; +enum class ConvAlgorithmSpecialization +{ + LARGE_TENSOR +}; + // ostream operator overloads for enum classes inline std::ostream& operator<<(std::ostream& os, DataType dt) { @@ -286,61 +245,6 @@ inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) } } -inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) -{ - using enum FwdGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: - return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; - case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: - return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; - case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: - return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; - case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: - return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; - case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: - return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) -{ - using enum BwdDataGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; - case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: - return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; - case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: - return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; - default: return os << "Unknown"; - } -} - -inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) -{ - using enum BwdWeightGroupConvDeviceOperation; - switch(op) - { - case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; - case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl"; - case DeviceGroupedConvBwdWeight_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; - case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: - return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; - case DeviceGroupedConvBwdWeight_Wmma_CShuffle: - return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; - case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; - case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; - case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: - return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; - default: return os << "Unknown"; - } -} - inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) { using enum ElementwiseOperation; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 43c4fd4857..5044d223de 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -39,7 +39,8 @@ add_ck_builder_test(test_ckb_get_instance_string add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_1d_i8.cpp + conv/test_ckb_conv_fwd_2d_fp8.cpp conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -47,7 +48,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp) + conv/test_ckb_conv_fwd_3d_fp32.cpp + ) function(add_ck_factory_test test_name) add_ck_builder_test(${test_name} ${ARGN}) @@ -66,10 +68,10 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) -add_ck_builder_test(test_conv_traits +add_ck_builder_test(test_ckb_conv_traits conv/test_conv_traits.cpp) -add_ck_builder_test(test_conv_description +add_ck_builder_test(test_ckb_conv_description test_conv_description.cpp) # Function to add all test_ckb targets to a list diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 123034eb77..bb0c767bbd 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -1,34 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D BF16 (channels-first) with Pipeline V2 and FILTER_1X1_STRIDE1_PAD0 specialization and SCALE // elementwise op TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, + .data_type = DataType::BF16, + .elementwise_operation = ElementwiseOperation::SCALE}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v2_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V2, - ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Stride1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v2"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index a83ca84297..d391c1a74d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -1,31 +1,35 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D FP16 (channels-last) with DEFAULT specialization TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::NWGC_GKXC_NWGK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "64, 64, 32, 32", "Default"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 3ceac2a047..7206c768d8 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -1,31 +1,35 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 1D I8 (channels-last) with and DEFAULT specialization TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 1, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, - .data_type = DataType::I8, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout1D::GNWC_GKXC_GNWK, + .data_type = DataType::I8, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} + .with_thread_block(FwdThreadBlock_128_64x64x64) + .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x32x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); - run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle", "128, 64, 64, 64", "Default"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 240746f546..d0bc1d7a6d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -1,54 +1,63 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v1_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Default", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v1"}); } // 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v5_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "Filter3x3", + "BlkGemmPipelineVersion: v5"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp index 12730bab19..0a337f3a7b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -1,69 +1,59 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} + .with_thread_block(FwdThreadBlock_256_128x128x16) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) + .with_dl_thread_cluster(DlThreadCluster_8x2) + .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) + .with_dl_epilogue(DlEpilogueC); - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); -} - -TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) -{ - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; - - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; - - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Default"}); } TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} + .with_thread_block(FwdThreadBlock_256_128x128x16) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) + .with_dl_thread_cluster(DlThreadCluster_8x2) + .with_dl_block_transfer(DlBlockTransferAB, DlBlockTransferAB) + .with_dl_epilogue(DlEpilogueC); - run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index 6366016707..798c6cfa2d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v3_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V3, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v3"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 7b303a7bde..a8313ff510 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, - .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, + .data_type = DataType::FP32, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v4_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V4, - ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 128, 128, 32", + "Filter1x1Stride1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v4"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp new file mode 100644 index 0000000000..39319bb79e --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp @@ -0,0 +1,35 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +// 2D FP8 NHWGC (channels-last) with Pipeline V1 and DEFAULT +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP8, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1_fp8) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + + using Builder = ConvBuilder; + run_test( + {"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle", "256, 256, 128, 32", "Default"}); +} + +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 0216c5907d..6c43678bf1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -1,53 +1,64 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ + .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; - run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::DEFAULT>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", + "256, 256, 128, 32", + "Default"}); } TEST( FwdConvInstances, Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 128, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ + .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} + .with_thread_block(FwdThreadBlock_128_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_block_transfer(FwdBlockTransfer_4x16x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; - run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - FwdConvSignature, - FwdThreadBlock, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor", + "128, 128, 128, 32", + "Filter1x1Pad0"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index b40dd0b0d7..2392d1efff 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -1,32 +1,38 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, - .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, + .data_type = DataType::BF16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v3_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Default", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v3"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index e0dad4e1a1..52e153098e 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -1,33 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, - .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, + .data_type = DataType::FP16, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v4_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V4, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 128, 128, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v4"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 43ffb3f89a..5d1656924c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -1,33 +1,39 @@ // Copyright (C) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "utils/ckb_conv_test_common.hpp" +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { using namespace ck_tile::builder::test_utils; -namespace ck_tile::builder::testing { - // 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0 TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ - .spatial_dim = 3, - .direction = ConvDirection::FORWARD, - .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, - .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH, - .device_operation = - FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; + constexpr ConvSignature FwdConvSignature{.spatial_dim = 3, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, + .data_type = DataType::FP32, + .elementwise_operation = + ElementwiseOperation::PASS_THROUGH}; - constexpr ThreadBlock FwdThreadBlock{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} + .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) + .with_block_transfer(FwdBlockTransfer_4x64x1) + .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_block_gemm(BlockGemmDesc_v1_intrawave); - run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - FwdConvSignature, - FwdThreadBlock, - PipelineVersion::V1, - ConvFwdSpecialization::FILTER_1X1_PAD0>(); + using Builder = ConvBuilder; + run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + "256, 256, 256, 32", + "Filter1x1Pad0", + "BlkGemmPipelineScheduler: Intrawave", + "BlkGemmPipelineVersion: v1"}); } -} // namespace ck_tile::builder::testing +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 88c5b5787a..fd756cf06e 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -117,103 +117,6 @@ struct BlockTransferABC AccessOrder src_access_order_b; }; -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -{ - ThreadBlock thread_block; - GridwiseXdlGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - BlockGemm block_gemm; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); -static_assert( - ckb::SpecifiesBlockGemm); -static_assert(ckb::SpecifiesGemmSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>); - -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle -{ - ThreadBlock thread_block; - GridwiseXdlGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - size_t num_gemm_k_prefetch_stages; - size_t num_groups_to_merge; - PipelineScheduler loop_scheduler; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseXdlGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert( - ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>); -static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesLoopScheduler); -static_assert( - ckb::SpecifiesNumGroupsToMerge); - -struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle -{ - ThreadBlock thread_block; - GridwiseWmmaGemm gridwise_gemm; - BlockTransferABC block_transfer; - ConvFwdSpecialization fwd_specialization; - GemmSpecialization gemm_specialization; - size_t num_gemm_k_prefetch_stages; - PipelineScheduler loop_scheduler; -}; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert(ckb::SpecifiesThreadBlock); -static_assert( - ckb::SpecifiesGridwiseWmmaGemm); -static_assert( - ckb::SpecifiesBlockTransfer); -static_assert(ckb::SpecifiesLdsTransfer); -static_assert(ckb::SpecifiesThreadClusterAccessOrder< - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>); -static_assert( - ckb::SpecifiesSourceAccessOrder); -static_assert( - ckb::SpecifiesFwdConcSpecialization); -static_assert( - ckb::SpecifiesNumPrefetchStages); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesLoopScheduler); - // DL-specific descriptors struct DlThreadConfig { @@ -227,12 +130,12 @@ static_assert(ckb::DlThreadConfigDescriptor); struct DlThreadCluster { - std::array m1_xs; // e.g., {8, 2} - std::array n1_xs; // e.g., {8, 2} + std::array m1_xs; + std::array n1_xs; }; static_assert(ckb::DlThreadClusterDescriptor); -struct DlBlockTransferK0M0M1K1 +struct DlBlockTransfer { std::array thread_slice_lengths; std::array thread_cluster_lengths; @@ -242,56 +145,212 @@ struct DlBlockTransferK0M0M1K1 std::array src_vector_tensor_contiguous_dim_order; std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); +static_assert(ckb::DlBlockTransferDescriptor); -struct DlBlockTransferK0N0N1K1 -{ - std::array thread_slice_lengths; - std::array thread_cluster_lengths; - std::array thread_cluster_arrange_order; - std::array src_access_order; - std::array src_vector_tensor_lengths; - std::array src_vector_tensor_contiguous_dim_order; - std::array dst_vector_tensor_lengths; -}; -static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); - -struct DlCThreadTransfer +struct DlEpilogue { std::array src_dst_access_order; size_t src_dst_vector_dim; size_t dst_scalar_per_vector; }; -static_assert(ckb::DlCThreadTransferDescriptor); +static_assert(ckb::DlEpilogueDescriptor); -struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +// Factory types + +struct ThreadBlock_ { ThreadBlock thread_block; +}; + +struct XdlGemm_ +{ + GridwiseXdlGemm gridwise_gemm; +}; + +struct WmmaGemm_ +{ + GridwiseWmmaGemm gridwise_gemm; +}; + +struct BlockTransfer_ +{ + BlockTransferABC block_transfer; +}; + +struct ConvSpecialization_ +{ ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; - DlThreadConfig dl_thread_config; - DlThreadCluster dl_thread_cluster; - DlBlockTransferK0M0M1K1 dl_block_transfer_a; - DlBlockTransferK0N0N1K1 dl_block_transfer_b; - DlCThreadTransfer dl_c_thread_transfer; }; -static_assert( - ckb::ConvAlgorithmDescriptor); -static_assert( - ckb::SpecifiesThreadBlock); -static_assert(ckb::SpecifiesFwdConcSpecialization< - ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); -static_assert( - ckb::SpecifiesGemmSpecialization); -static_assert( - ckb::SpecifiesDlThreadConfig); -static_assert( - ckb::SpecifiesDlThreadCluster); -static_assert( - ckb::SpecifiesDlBlockTransferA); -static_assert( - ckb::SpecifiesDlBlockTransferB); -static_assert( - ckb::SpecifiesDlCThreadTransfer); + +struct Prefetch_ +{ + size_t num_gemm_k_prefetch_stages; + size_t num_groups_to_merge; + PipelineScheduler loop_scheduler; +}; + +struct BlockGemm_ +{ + BlockGemm block_gemm; +}; + +struct DlThreadConfig_ +{ + DlThreadConfig thread_config; +}; + +struct DlThreadCluster_ +{ + DlThreadCluster thread_cluster; +}; + +struct DlBlockTransfer_ +{ + DlBlockTransfer block_transfer_a; + DlBlockTransfer block_transfer_b; +}; + +struct DlEpilogue_ +{ + DlEpilogue epilogue_c; +}; + +// Specialization wrapper for large tensor support +template +struct LargeTensorWrapper +{ + BaseAlgorithm base_algorithm; + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::LARGE_TENSOR; +}; + +// Factory + +template +struct ConvAlgorithmTemplate : Components... +{ + + template + constexpr auto with_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_gemm_config(const GemmConfig& gemm) const + { + auto result = *this; + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + else if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + return result; + } + + template + constexpr auto with_block_transfer(const BT& bt) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_transfer = bt; + return result; + } + + constexpr auto with_specializations(ConvFwdSpecialization fwd_spec, + GemmSpecialization gemm_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.fwd_specialization = fwd_spec; + result.gemm_specialization = gemm_spec; + return result; + } + + constexpr auto with_prefetch_config(size_t k_prefetch_stages, + size_t groups_to_merge, + PipelineScheduler scheduler) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_gemm_k_prefetch_stages = k_prefetch_stages; + result.num_groups_to_merge = groups_to_merge; + result.loop_scheduler = scheduler; + return result; + } + + template + constexpr auto with_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_dl_thread_config(const TC& tc) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_config = tc; + return result; + } + + template + constexpr auto with_dl_thread_cluster(const TCl& tcl) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_cluster = tcl; + return result; + } + + template + constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_transfer_a = bta; + result.block_transfer_b = btb; + return result; + } + + constexpr auto with_dl_epilogue(const DlEpilogue& epi) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.epilogue_c = epi; + return result; + } +}; + +// Algorithm types + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvAlgorithmTemplate; +using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvAlgorithmTemplate; + +using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + LargeTensorWrapper; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 5e6684c4cd..71f16aefbe 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -17,7 +17,6 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; - GroupConvDeviceOp device_operation; }; static_assert(ConvSignatureDescriptor); diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 97af4af795..733359d491 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -25,8 +25,6 @@ struct ConvSignature ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; ckb::DataType data_type = ckb::DataType::FP16; ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; - ckb::GroupConvDeviceOp device_operation = - ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; }; static_assert(ckb::ConvSignatureDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp deleted file mode 100644 index 14fae566f6..0000000000 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "impl/conv_algorithm_types.hpp" -#include "impl/conv_signature_types.hpp" -#include "ck_tile/builder/conv_builder.hpp" - -namespace ck_tile::builder::test_utils { - -using namespace ck_tile::builder; -using namespace test; - -// Common test implementation -template -constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() -{ - constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = PipelineScheduler::INTRAWAVE}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .block_gemm = BlockGemmDesc}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); - - // Verify pipeline version is correct - if(FwdPipelineVersion == PipelineVersion::V1) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V3) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V4) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == PipelineVersion::V5) - EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() -{ - constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .num_groups_to_merge = 2, - .loop_scheduler = PipelineScheduler::DEFAULT}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() -{ - constexpr GridwiseWmmaGemm FwdGemmParams{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = PipelineVersion::V1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = PipelineScheduler::DEFAULT}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Wmma_CShuffle")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -template -constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() -{ - // DL thread configuration - constexpr DlThreadConfig DlThreadCfg{ - .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; - - // DL thread cluster - constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; - - // DL A block transfer - K0_M0_M1_K1 format - constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ - .thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; - - // DL B block transfer - K0_N0_N1_K1 format - constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ - .thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; - - // DL C thread transfer - constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}; - - constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .dl_thread_config = DlThreadCfg, - .dl_thread_cluster = DlCluster, - .dl_block_transfer_a = DlBlockTransferA, - .dl_block_transfer_b = DlBlockTransferB, - .dl_c_thread_transfer = DlCTransfer}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -// Note: Large_Tensor has identical parameters to regular XDL CShuffle -template -constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() -{ - constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 1}; - - constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; - - // Large_Tensor uses the same descriptor as regular XDL CShuffle - constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ - .thread_block = FwdThreadBlock, - .gridwise_gemm = FwdGemmParams, - .block_transfer = FwdBlockTransfer, - .fwd_specialization = FwdConvSpecialization, - .gemm_specialization = GemmSpecialization::MNKPadding, - .num_gemm_k_prefetch_stages = 1, - .num_groups_to_merge = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; - - using Builder = ConvBuilder; - - auto instance = typename Builder::Instance{}; - - const auto kernel_string = instance.GetTypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; - EXPECT_GT(kernel_string.size(), 0); - - EXPECT_TRUE( - kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); - - // Verify specialization is correct - if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) - EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); - else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) - EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); - - const auto invoker_ptr = instance.MakeInvokerPointer(); - EXPECT_NE(invoker_ptr, nullptr); -} - -} // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp new file mode 100644 index 0000000000..7f2acce9c8 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -0,0 +1,184 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + +constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + +constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + +constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + +constexpr BlockTransferABC FwdBlockTransfer_4x64x1{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ + .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ + .block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; + +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; + +constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1, + .pipeline_version = PipelineVersion::V1}; + +constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp new file mode 100644 index 0000000000..6f9442d067 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -0,0 +1,31 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck_tile::builder::test_utils { +using namespace test; + +// Common test implementation +template +constexpr void run_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); + + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + +} // namespace ck_tile::builder::test_utils