diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 108ccc0425..4d81becfb5 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -51,12 +51,24 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; }; + +template +concept HasGemmKBatch = requires(T t) { + { t.k_batch_size}; +}; + +// Concept to check if GEMM k batch size is specified. +template +concept GemmKBatchSizeWellDefinedIfProvided = + !HasGemmKBatch || requires(T t) { {t.k_batch_size} -> std::convertible_to; }; + // Concept for vectorized data transfer for convolution input tensors. template concept BlockTransferDescriptor = requires(T t) { { t.k0 } -> std::convertible_to; { t.m_n } -> std::convertible_to; { t.k1 } -> std::convertible_to; + GemmKBatchSizeWellDefinedIfProvided; }; // Concept for thread cluster dimensions for GEMM output tensor. @@ -91,6 +103,8 @@ concept EpilogueDescriptor = requires(T t) { template concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; +} || requires(T t) { + { t.order } -> std::convertible_to>; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block @@ -166,7 +180,6 @@ concept GridwiseFwdXdlGemmDescriptor = requires (T t){ // Concept to check if a struct specifies gridwise XDL GEMM info. template concept GridwiseBwdXdlGemmDescriptor = requires (T t){ - { t.k0_per_block } -> std::convertible_to; { t.k1 } -> std::convertible_to; { t.xdl_params } -> GridwiseXdlGemmDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index f97ef7c275..3b75eb00e9 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -181,6 +181,14 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { msg += std::string(" → ") + prefix + ".k1: [✗] (missing member)\n"; } + // k_batch_size is optional - only report if it exists and has wrong type + if constexpr (requires(BT t) { t.k_batch_size; }) { + using KBatchType = decltype(std::declval().k_batch_size); + constexpr bool convertible = std::convertible_to; + msg += std::string(" → ") + prefix + ".k_batch_size (optional): " + std::string(CHECK_MARK(convertible)) + + std::string(get_type_info()) + "\n"; + } + return msg; } @@ -288,7 +296,9 @@ consteval auto diagnose_access_order(const char* prefix) -> std::string { if constexpr (requires(AO t) { t.order; }) { using OrderType = decltype(std::declval().order); - constexpr bool convertible = std::convertible_to>; + constexpr bool convertible_3 = std::convertible_to>; + constexpr bool convertible_4 = std::convertible_to>; + constexpr bool convertible = convertible_3 || convertible_4; msg += std::string(" → ") + prefix + ".order: " + std::string(CHECK_MARK(convertible)) + std::string(get_type_info()) + "\n"; } else { @@ -401,15 +411,6 @@ consteval auto detailed_diagnostic_SpecifiesGridwiseBwdXdlGemm() -> std::string msg += " → T::gridwise_gemm member: [✓]\n"; using GG = decltype(T::gridwise_gemm); - if constexpr (requires(GG t) { t.k0_per_block; }) { - using K0Type = decltype(std::declval().k0_per_block); - constexpr bool convertible = std::convertible_to; - msg += " → gridwise_gemm.k0_per_block: " + std::string(CHECK_MARK(convertible)) + - std::string(detail::get_type_info()) + "\n"; - } else { - msg += " → gridwise_gemm.k0_per_block: [✗] (missing member)\n"; - } - if constexpr (requires(GG t) { t.k1; }) { using K1Type = decltype(std::declval().k1); constexpr bool convertible = std::convertible_to; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index cc3262c07c..8cb0d5d2ae 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -69,7 +69,7 @@ struct ConvBwdWeightXdlFactory BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, - GRIDWISE_GEMM.k0_per_block, + BLOCK.per_block.k, GRIDWISE_GEMM.k1, XDL_PARAMS.m_per_xdl, XDL_PARAMS.n_per_xdl, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 9729a72ce7..5ee07fa2d1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -63,9 +63,9 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() auto& lds_cfg = TRANSFER.lds_transfer; return BwdBlockTransfer{ - .thread_cluster_dims = {1, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, - .thread_cluster_order = {0, block_order.order[0], block_order.order[1], block_order.order[2]}, - .src_access_order = {0, src_order.order[0], src_order.order[1], src_order.order[2]}, + .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp index f626bbb288..9989a20ad0 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight.cpp @@ -21,7 +21,7 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} - .with_thread_block(cku::ThreadBlock_256_128x128x32) + .with_thread_block(cku::ThreadBlock_256_128x128x8) .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) .with_transfer(cku::BwdTransfer_4x64x1) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); @@ -35,11 +35,8 @@ TEST(BwdWeight_2DFp16_CShufV3_GNHWC, Create) cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", expected_transfer_parameters, "Default", - "Intrawave", - "v3", - "GNHWC,GKYXC,EmptyTuple,GNHWK", - "PassThrough,PassThrough,PassThrough", - "MNKPadding"}); + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough"}); } // TEST(BwdWeight_2DFp16_CShufV3_GNHWC, EndToEnd) diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 4d5ac2cd9e..2d2829820d 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,7 +49,6 @@ static_assert(ckb::GridwiseFwdXdlGemmDescriptor); struct GridwiseBwdXdlGemm { - size_t k0_per_block = 0; size_t k1 = 0; XdlParams xdl_params; }; @@ -75,13 +74,25 @@ struct BlockGemm static_assert(ckb::BlockGemmDescriptor); // Describe Aand B block transfer thread cluster lengths. +template struct BlockTransfer { size_t k0; size_t m_n; size_t k1; + size_t k_batch_size; }; -static_assert(ckb::BlockTransferDescriptor); + +// Specialization for forward (IsBwd = false) +template <> +struct BlockTransfer +{ + size_t k0; + size_t m_n; + size_t k1; +}; +static_assert(ckb::BlockTransferDescriptor>); +static_assert(ckb::BlockTransferDescriptor>); // Describe C block transfer thread cluster lengths. struct ThreadCluster @@ -111,31 +122,35 @@ struct Epilogue }; static_assert(EpilogueDescriptor); +template struct AccessOrder { - std::array order; + std::array order; }; -static_assert(AccessOrderDescriptor); +static_assert(AccessOrderDescriptor>); +static_assert(AccessOrderDescriptor>); -struct TransferAB +template +struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; - AccessOrder src_access_order; + std::conditional_t, AccessOrder<3>> block_transfer_access_order; + std::conditional_t, AccessOrder<3>> src_access_order; }; -struct TransferC +struct OutputTransfer { ThreadCluster thread_cluster_dims; Epilogue epilogue; }; -struct TransferABC +template +struct Transfer { - TransferAB a; - TransferAB b; - TransferC c; + InputTransfer a; + InputTransfer b; + OutputTransfer c; }; // DL-specific descriptors @@ -198,9 +213,10 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; +template struct Transfer_ { - TransferABC transfer; + Transfer transfer; }; struct ConvSpecializationFwd_ @@ -380,7 +396,8 @@ struct ConvAlgorithmTemplate : Components... template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -511,13 +528,13 @@ struct ConvAlgorithmTemplate : Components... // Algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationFwd_, Prefetch_>; using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, ConvSpecializationBwdWeight_, TransposeParams_>; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index d176506526..956f65f453 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -39,7 +39,7 @@ constexpr DlTransferABC DlFwdTransfer{.a = .dst_scalar_per_vector = 4}, }}; -constexpr TransferABC Transfer_4x64x1{ +constexpr Transfer<> Transfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -72,28 +72,29 @@ constexpr TransferABC Transfer_4x64x1{ }, }; -constexpr TransferABC BwdTransfer_4x64x1{ +constexpr bool BWD = true; +constexpr Transfer BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {3, 1, 2}, - .src_access_order = {2, 1, 3}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, .lds_transfer = {.src_vector_dim = 2, .src_scalar_per_vector = 2, .lds_dst_scalar_per_vector = 4, .is_direct_load = false, .lds_padding = true}, - .block_transfer_access_order = {3, 1, 2}, - .src_access_order = {2, 1, 3}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .c = { @@ -105,7 +106,7 @@ constexpr TransferABC BwdTransfer_4x64x1{ }, }; -constexpr TransferABC Transfer_4x64x1_fp8{ +constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -138,7 +139,7 @@ constexpr TransferABC Transfer_4x64x1_fp8{ }, }; -constexpr TransferABC Transfer_4x16x1{ +constexpr Transfer<> Transfer_4x16x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, @@ -172,7 +173,7 @@ constexpr TransferABC Transfer_4x16x1{ }, }; -constexpr TransferABC Transfer_4x32x1{ +constexpr Transfer<> Transfer_4x32x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -206,7 +207,7 @@ constexpr TransferABC Transfer_4x32x1{ }; constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ - .k0_per_block = 8, .k1 = 8, + .k1 = 8, .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ @@ -244,6 +245,9 @@ constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256, constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 8}}; + constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, .tile_size = {.m = 64, .n = 32, .k = 32}}; diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index c3afe2bd4e..f7096f27f8 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -89,7 +89,7 @@ template <> inline std::string to_string(GridwiseBwdXdlGemm t) { std::ostringstream oss; - oss << t.k0_per_block << "," << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -120,10 +120,17 @@ inline std::string to_string(BlockGemm t) return oss.str(); } -template <> -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + if constexpr (IsBwd) + { + return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); + } + else + { + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + } } template <> @@ -143,14 +150,14 @@ inline std::string to_string(LdsTransfer t) return oss.str(); } -template <> -inline std::string to_string(AccessOrder t) +template +inline std::string to_string(AccessOrder t) { return array_to_seq(t.order); } -template <> -inline std::string to_string(TransferAB t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," @@ -161,7 +168,7 @@ inline std::string to_string(TransferAB t) } template <> -inline std::string to_string(TransferC t) +inline std::string to_string(OutputTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," @@ -169,8 +176,8 @@ inline std::string to_string(TransferC t) return oss.str(); } -template <> -inline std::string to_string(TransferABC t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -260,8 +267,8 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template <> -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } @@ -323,7 +330,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -333,7 +340,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -343,7 +350,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -371,8 +378,9 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); }