diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index f60e7703a3..d35897fc78 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -29,18 +29,20 @@ concept OutputVectorTransferLimits = requires { // Limits for access order. Must be a permutation of {0, 1, 2}. template -concept AccessOrderLimits = requires { +concept AccessOrderLimits3D = requires { requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) && (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) && - (Value[2] >= 0 && Value[2] < 3)); + (Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3)); }; -// Limits for access order. Must be a permutation of {1, 2, 3} for the last three elements. +// Limits for access order. Must be a permutation of {0, 1, 2, 3}. template -concept BwdAccessOrderLimits = requires { - requires((Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && - (Value[1] >= 1 && Value[1] < 4) && (Value[2] >= 1 && Value[2] < 4) && - (Value[3] >= 1 && Value[3] < 4)) && (Value[0] == 0); +concept AccessOrderLimits4D = requires { + requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) && + (Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && + (Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) && + (Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) && + (Value.Size() == 4)); }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 71e3d03da3..e537b7ba99 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -279,6 +279,47 @@ struct BwdXdlAlgorithm { } }; +template +struct BwdXdlV3Algorithm { + CHECK_CONCEPT(T, ConvAlgorithmDescriptor) + CHECK_CONCEPT(T, SpecifiesThreadBlock) + CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesLdsTransfer) + CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) + CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) + CHECK_CONCEPT(T, SpecifiesGridwiseBwdXdlGemm) + CHECK_CONCEPT(T, SpecifiesBwdWeightConvSpecialization) + CHECK_CONCEPT(T, SpecifiesBlockGemm) + + static constexpr bool c1 = c_ConvAlgorithmDescriptor; + static constexpr bool c2 = c_SpecifiesThreadBlock; + static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c4 = c_SpecifiesLdsTransfer; + static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; + static constexpr bool c6 = c_SpecifiesSourceAccessOrder; + static constexpr bool c7 = c_SpecifiesGridwiseBwdXdlGemm; + static constexpr bool c8 = c_SpecifiesBwdWeightConvSpecialization; + static constexpr bool c9 = c_SpecifiesBlockGemm; + + static consteval bool is_valid() { + return c1 && c2 && c3 && c4 && c5 && c6 && c7 && c8 && c9; + } + + static consteval auto message() -> std::string { + return std::string("\n=== Backward XDL V3 Algorithm Diagnostic (closest match) ===\n" + "Concepts for BwdXdlV3 Algorithm:\n") + + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + + DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + + DIAGNOSTIC_LINE(SpecifiesGridwiseBwdXdlGemm) + + DIAGNOSTIC_LINE(SpecifiesBwdWeightConvSpecialization) + + DIAGNOSTIC_LINE(SpecifiesBlockGemm); + } +}; + template consteval int count_matches_fwd_xdl_v3() { using Alg = FwdXdlV3Algorithm; @@ -309,6 +350,12 @@ consteval int count_matches_bwd_xdl() { return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; } +template +consteval int count_matches_bwd_xdl_v3() { + using Alg = BwdXdlV3Algorithm; + return Alg::c1 + Alg::c2 + Alg::c3 + Alg::c4 + Alg::c5 + Alg::c6 + Alg::c7 + Alg::c8 + Alg::c9; +} + template consteval int count_matches_large_tensor() { using Alg = LargeTensorAlgorithm; @@ -368,12 +415,20 @@ consteval void diagnose_fwd_algorithm_signature() template consteval void diagnose_bwd_weight_algorithm_signature() { - constexpr int xdl_matches = count_matches_fwd_xdl(); - constexpr int max_matches = xdl_matches; + constexpr int xdl_matches = count_matches_bwd_xdl(); + constexpr int xdl_v3_matches = count_matches_fwd_xdl_v3(); + + constexpr int max_matches = xdl_v3_matches > xdl_matches ? xdl_v3_matches : xdl_matches; + if constexpr (max_matches == xdl_matches) { using Alg = BwdXdlAlgorithm; static_assert(Alg::is_valid(), Alg::message()); - } else { + } + else if constexpr (max_matches == xdl_v3_matches) { + using Alg = BwdXdlV3Algorithm; + static_assert(Alg::is_valid(), Alg::message()); + } + else { // This should never happen static_assert(false, "Internal Error: No matching algorithm variant found for diagnostics."); } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 6ad5820dab..8790121ed9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -47,10 +47,10 @@ struct ConvBwdWeightXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); - static_assert(BwdAccessOrderLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp new file mode 100644 index 0000000000..14121be940 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::BwdWeightConvTensorDataTypes; + using Ops = internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, "Invalid A source access order"); + static_assert(AccessOrderLimits3D, "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::InComputeType, + typename Types::WeiComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index adbf12992e..01c0fb9c56 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -49,6 +49,12 @@ #pragma once +// Disable pragma message warnings for factory selection diagnostics +#ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-W#pragma-messages" +#endif + #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" @@ -64,6 +70,7 @@ #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" #include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -96,28 +103,34 @@ constexpr auto make_conv_instance() // CK Tile supports common factory for each direction if constexpr(TileAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvTileFactory...") return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { if constexpr(FwdXdlV3Algorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdXdlV3Factory...") return typename ConvFwdXdlV3Factory::Instance{}; } else if constexpr(FwdXdlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdXdlFactory...") return typename ConvFwdXdlFactory::Instance{}; } else if constexpr(FwdWmmaAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdWmmaFactory...") return typename ConvFwdWmmaFactory::Instance{}; } else if constexpr(FwdDlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdDlFactory...") return typename ConvFwdDlFactory::Instance{}; } else if constexpr(LargeTensorAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvFwdLargeTensorFactory...") return typename ConvFwdLargeTensorFactory::Instance{}; } else @@ -135,8 +148,14 @@ constexpr auto make_conv_instance() { if constexpr (BwdXdlAlgorithm::is_valid()) { + #pragma message("[CK Builder] Using ConvBwdWeightXdlFactory...") return typename ConvBwdWeightXdlFactory::Instance{}; } + else if constexpr (BwdXdlV3Algorithm::is_valid()) + { + #pragma message("[CK Builder] Using ConvBwdWeightXdlV3Factory...") + return typename ConvBwdWeightXdlV3Factory::Instance{}; + } else { diagnose_bwd_weight_algorithm_signature(); @@ -152,3 +171,8 @@ constexpr auto make_conv_instance() } } // namespace ck_tile::builder::factory + +// Re-enable pragma message warnings +#ifdef __clang__ + #pragma clang diagnostic pop +#endif diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 62547dbe32..456c567aa0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -54,10 +54,10 @@ struct ConvFwdLargeTensorFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance with large tensor support. using Instance = diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 30fd555dd9..e34f39965f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -56,10 +56,10 @@ struct ConvFwdXdlV3Factory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 1fb3942df0..dbaa8651eb 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -52,10 +52,10 @@ struct ConvFwdWmmaFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 16baf4fbce..cebf5a0c3a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -51,10 +51,10 @@ struct ConvFwdXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 5ee07fa2d1..25cf773694 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -62,16 +62,40 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BwdBlockTransfer{ - .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, - .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, - .src_vector_dim = lds_cfg.src_vector_dim, - .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, - .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, - .lds_padding = lds_cfg.lds_padding, - }; + constexpr auto array_length = block_order.order.size(); + static_assert(block_order.order.size() == src_order.order.size(), + "Mismatched size between block order and src order"); + + if constexpr (array_length == 3) + { + return BwdBlockTransfer{ + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; + } + else if constexpr (array_length == 4) + { + return BwdBlockTransfer{ + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; + } + else + { + static_assert(false, "Internal error: Unsupported array length"); + } } // Block transfer parameters for C tensor. diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8ffdf3f543..8105a41bf5 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -142,6 +142,7 @@ target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) add_ck_builder_test(test_ckb_build_bwd_weight_instances conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp ) target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp new file mode 100644 index 0000000000..2dfa6e5771 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", + expected_transfer_parameters, + "FILTER_1X1_STRIDE1_PAD0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v2"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index c87ffbd066..22911a1a26 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, .accumulation_data_type = INT32, .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = GNWK}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index e3c947f541..b045d185e2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -74,7 +74,7 @@ struct BlockGemm static_assert(ckb::BlockGemmDescriptor); // Describe Aand B block transfer thread cluster lengths. -template +template struct BlockTransfer { size_t k0; @@ -83,16 +83,16 @@ struct BlockTransfer size_t k_batch_size; }; -// Specialization for forward (IsBwd = false) +// Specialization for ThreadSliceLength == 3 template <> -struct BlockTransfer +struct BlockTransfer<3> { size_t k0; size_t m_n; size_t k1; }; static_assert(ckb::BlockTransferDescriptor>); -static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::BlockTransferDescriptor>); // Describe C block transfer thread cluster lengths. struct ThreadCluster @@ -130,13 +130,13 @@ struct AccessOrder static_assert(AccessOrderDescriptor>); static_assert(AccessOrderDescriptor>); -template +template struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - std::conditional_t, AccessOrder<3>> block_transfer_access_order; - std::conditional_t, AccessOrder<3>> src_access_order; + AccessOrder block_transfer_access_order; + AccessOrder src_access_order; }; struct OutputTransfer @@ -145,11 +145,11 @@ struct OutputTransfer Epilogue epilogue; }; -template +template struct Transfer { - InputTransfer a; - InputTransfer b; + InputTransfer a; + InputTransfer b; OutputTransfer c; }; @@ -213,10 +213,10 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; -template +template struct Transfer_ { - Transfer transfer; + Transfer transfer; }; struct ConvSpecializationFwd_ @@ -397,7 +397,7 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_transfer(const T& t) const { static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || - std::is_base_of_v, ConvAlgorithmTemplate>); + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -553,6 +553,9 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, BlockGemm_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 956f65f453..7b5807ef23 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -72,8 +72,7 @@ constexpr Transfer<> Transfer_4x64x1{ }, }; -constexpr bool BWD = true; -constexpr Transfer BwdTransfer_4x64x1{ +constexpr Transfer<4> BwdTransfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, @@ -106,6 +105,39 @@ constexpr Transfer BwdTransfer_4x64x1{ }, }; +constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, + }, +}; + constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { @@ -210,6 +242,10 @@ constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; + constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ .ak1 = 8, .bk1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; @@ -251,6 +287,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64, + .tile_size = {.m = 32, .n = 32, .k = 32}}; + constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index f7096f27f8..cf13f39391 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -120,17 +120,21 @@ inline std::string to_string(BlockGemm t) return oss.str(); } -template -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - if constexpr (IsBwd) + if constexpr (ThreadSliceDim == 4) { return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); } - else + else if constexpr (ThreadSliceDim == 3) { return array_to_seq(std::array{t.k0, t.m_n, t.k1}); } + else + { + static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim"); + } } template <> @@ -156,8 +160,8 @@ inline std::string to_string(AccessOrder t) return array_to_seq(t.order); } -template -inline std::string to_string(InputTransfer t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," @@ -176,8 +180,8 @@ inline std::string to_string(OutputTransfer t) return oss.str(); } -template -inline std::string to_string(Transfer t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -267,8 +271,8 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } @@ -378,9 +382,18 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); }