From 2460cf4579b7a5353bf22d43d4aa5fbfc868e484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 19 Dec 2025 07:59:37 -0500 Subject: [PATCH] Initial conv bwd weight factory. --- .../builder/conv_algorithm_concepts.hpp | 19 +++- .../factory/conv_bwd_weight_xdl_factory.hpp | 102 ++++++++++++++++++ .../builder/factory/conv_dispatcher.hpp | 69 +++++++----- .../builder/factory/conv_fwd_dl_factory.hpp | 2 +- .../factory/conv_fwd_large_tensor_factory.hpp | 2 +- .../builder/factory/conv_fwd_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 2 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 2 +- .../helpers/ck/conv_elementwise_op.hpp | 13 +++ .../factory/helpers/ck/conv_tensor_layout.hpp | 19 +++- .../factory/helpers/ck/conv_tensor_type.hpp | 21 ++++ .../factory/helpers/ck/conv_tuning_params.hpp | 27 +++-- .../builder/include/ck_tile/builder/types.hpp | 4 +- .../test/conv/ck/test_ckb_conv_bwd_weight.cpp | 4 +- .../test/impl/conv_algorithm_types.hpp | 84 ++++++++++++--- 15 files changed, 310 insertions(+), 62 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index bf7e89fcaa..fcc9a09c6d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -27,8 +27,6 @@ concept ThreadBlockDescriptor = requires(T t) { // Concept for parameters that describe a gridwise XDL GEMM problem. template concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> std::convertible_to; - { t.bk1 } -> std::convertible_to; { t.m_per_xdl } -> std::convertible_to; { t.n_per_xdl } -> std::convertible_to; { t.m_xdl_per_wave } -> std::convertible_to; @@ -159,7 +157,17 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseXdlGemm = requires { +concept SpecifiesGridwiseFwdXdlGemm = requires { + { T::gridwise_gemm.ak1 } -> std::convertible_to; + { T::gridwise_gemm.bk1 } -> std::convertible_to; + { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdXdlGemm = requires { + { T::gridwise_gemm.k0_per_block } -> std::convertible_to; + { T::gridwise_gemm.k1 } -> std::convertible_to; { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; }; @@ -247,6 +255,11 @@ concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; +template +concept SpecifiesBwdWeightConvSpecialization = requires { + { T::bwd_weight_specialization } -> std::convertible_to; +}; + template concept SpecifiesGemmSpecialization = requires { { T::gemm_specialization } -> std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp new file mode 100644 index 0000000000..9889356092 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k0_per_block, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::InComputeType, + typename Types::WeiComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 99e7479e36..b18d54f489 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -60,6 +60,7 @@ #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weigth_xdl_factory.hpp" namespace ck_tile::builder::factory { @@ -88,34 +89,43 @@ concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock SpecifiesTileTransfer && SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && SpecifiesTileOptimizations; +template +concept SpecifiesDataTransfer = + SpecifiesThreadBlock && SpecifiesBlockTransfer && + SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder; + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template -concept IsXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; +concept IsFwdXdlV3Algorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesBlockGemm; -// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) +// Standard XDL-based fwd kernel (uses XDLops hardware instructions for matrix multiply) template -concept IsXdlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; +concept IsFwdXdlAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseFwdXdlGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; + +// Standard XDL-based bwd weight kernel (uses XDLops hardware instructions for matrix multiply) +template +concept IsBwdXdlAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseBwdXdlGemm && + SpecifiesBwdWeightConvSpecialization && SpecifiesTransposeTransfer; // WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) template -concept IsWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; +concept IsFwdWmmaAlgorithm = ConvAlgorithmDescriptor && + SpecifiesDataTransfer && SpecifiesGridwiseWmmaGemm && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; // Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts template -concept IsDlAlgorithm = +concept IsFwdDlAlgorithm = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; @@ -139,19 +149,19 @@ constexpr auto make_conv_instance() } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm) + if constexpr(IsFwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm) + else if constexpr(IsFwdXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm) + else if constexpr(IsFwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm) + else if constexpr(IsFwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } @@ -177,10 +187,17 @@ constexpr auto make_conv_instance() } else if constexpr(ConvDirectionIsBackwardWeight) { - static_assert( - false, - "Backward weight convolution is not yet supported. " - "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + if constexpr (IsBwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else + { + static_assert( + false, + "Backward weight convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ca202aabfd..42c59dfaec 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,7 +24,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fadf41f48a..fca3638697 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 89787cc1b3..47891869cc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index bb84479071..1fb3942df0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 8ec5c633ce..695f154614 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::FwdConvTensorDataTypes; using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index a39cd7410b..b24344c90a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,6 +62,7 @@ consteval auto GetElementwiseOp() } template +requires ConvDirectionIsForward struct ElementwiseOps { static constexpr auto input_op = GetElementwiseOp(); @@ -72,4 +73,16 @@ struct ElementwiseOps using CDEElementwiseOp = typename decltype(output_op)::Op; }; +template +requires ConvDirectionIsBackwardWeight +struct ElementwiseOps +{ + static constexpr auto input_op = GetElementwiseOp(); + static constexpr auto weight_op = GetElementwiseOp(); + static constexpr auto output_op = GetElementwiseOp(); + using InElementwiseOp = typename decltype(input_op)::Op; + using WeiElementwiseOp = typename decltype(weight_op)::Op; + using OutElementwiseOp = typename decltype(output_op)::Op; +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index a6c0b48c54..3df3f8f37c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -216,18 +216,31 @@ consteval auto GetAuxiliaryTensorLayouts() return EmptyAuxiliaryTensorLayout{}; } -template +template requires(ConvSpatialDim && ValidConvInputLayoutForSpatialDim && ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim) + ValidConvOutputLayoutForSpatialDim && + ConvDirectionIsForward) struct ConvTensorLayouts { - static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported."); using ALayout = decltype(TensorLayoutToCK()); using BLayout = decltype(TensorLayoutToCK()); using ELayout = decltype(TensorLayoutToCK()); using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; }; +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim && + ConvDirectionIsBackwardWeight) +struct ConvTensorLayouts +{ + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index c819e11d00..7839cb7f4a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -151,6 +151,7 @@ consteval auto GetAuxiliaryTensorDataTypes() } template +requires ConvDirectionIsForward struct FwdConvTensorDataTypes { static constexpr auto input_types = @@ -176,4 +177,24 @@ struct FwdConvTensorDataTypes using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; +template +requires ConvDirectionIsBackwardWeight +struct FwdConvTensorDataTypes +{ + static constexpr auto input_types = + GetTensorDataAndComputeTypes(); + static constexpr auto weight_types = + GetTensorDataAndComputeTypes(); + static constexpr auto output_types = + GetTensorDataAndComputeTypes(); + + using InDataType = typename decltype(input_types.first)::type; + using InComputeType = typename decltype(input_types.second)::type; + using WeiDataType = typename decltype(weight_types.first)::type; + using WeiComputeType = typename decltype(weight_types.second)::type; + using AccDataType = + typename decltype(GetTensorAccumulationType())::type; +}; + } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index db741f2112..6f3a9e8e78 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -149,12 +149,27 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(specialization) { - case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; - case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC; - default: throw "Unknown ConvFwdSpecialization"; + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; + } +} + +template +consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization SetBwdWeightConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; } } diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index f7386720b3..5f08b5ab9c 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -192,8 +192,8 @@ enum class TileConvSpecialization FILTER_3x3 }; -// Enums for the forward convolution specialization. -enum class ConvFwdSpecialization +// Enums for the convolution specializations. +enum class ConvSpecialization { DEFAULT, FILTER_1X1_PAD0, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index 7ae0ba27ea..87de6dab03 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -20,7 +20,7 @@ constexpr auto SIGNATURE = .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; -constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} .with_thread_block(cku::FwdThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::FwdTransfer_4x64x1) @@ -34,7 +34,7 @@ using Instance = Builder::Instance; TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); - cku::run_test({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3", + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", expected_transfer_parameters, "Default", "Intrawave", diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 29c7f3cdcc..6b87ae77d6 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,18 +28,30 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseXdlGemm +struct XdlParams { - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; size_t m_per_xdl = 0; size_t n_per_xdl = 0; size_t m_xdl_per_wave = 0; size_t n_xdl_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::GridwiseXdlGemmDescriptor); + +// Describe gridwise XDL GEMM parameters. +struct GridwiseFwdXdlGemm : public XdlParams +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + size_t ak1 = 0; + size_t bk1 = 0; +}; +static_assert(ckb::SpecifiesGridwiseFwdXdlGemm); + +struct GridwiseBwdXdlGemm : public XdlParams +{ + size_t k0_per_block = 0; + size_t k1 = 0; +}; +static_assert(ckb::SpecifiesGridwiseBwdXdlGemm); // Describe gridwise WMMA GEMM parameters. struct GridwiseWmmaGemm @@ -169,9 +181,14 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct XdlGemm_ +struct FwdXdlGemm_ { - GridwiseXdlGemm gridwise_gemm; + GridwiseFwdXdlGemm gridwise_gemm; +}; + +struct BwdXdlGemm_ +{ + GridwiseBwdXdlGemm gridwise_gemm; }; struct WmmaGemm_ @@ -184,12 +201,17 @@ struct Transfer_ TransferABC transfer; }; -struct ConvSpecialization_ +struct ConvSpecializationFwd_ { - ConvFwdSpecialization fwd_specialization; + ConvSpecialization fwd_specialization; GemmSpecialization gemm_specialization; }; +struct ConvSpecializationBwdWeight_ +{ + ConvSpecialization bwd_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; @@ -197,6 +219,12 @@ struct Prefetch_ PipelineScheduler loop_scheduler; }; +struct TransposeParams_ +{ + size_t max_transpose_transfer_src_scalar_per_vector{1}; + size_t max_transpose_transfer_dst_scalar_per_vector{1}; +}; + struct BlockGemm_ { BlockGemm block_gemm; @@ -329,7 +357,11 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } @@ -359,6 +391,14 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_specializations(ConvBwdWeightSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_specialization = bwd_spec; + return result; + } + constexpr auto with_prefetch_config(size_t k_prefetch_stages, size_t groups_to_merge, PipelineScheduler scheduler) const @@ -371,6 +411,16 @@ struct ConvAlgorithmTemplate : Components... return result; } + constexpr auto with_transpose_params(bool max_src_scalar_per_vector, + bool max_dst_scalar_per_vector) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector; + result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { @@ -456,16 +506,17 @@ struct ConvAlgorithmTemplate : Components... // Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate; + using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; @@ -479,4 +530,7 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test