diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index cdbe805cdd..447bbdad5e 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -65,7 +65,7 @@ concept BlockTransferDescriptor = requires(T t) { }; template -concept BlockTransferDescriptorBwd = requires(T t) { +concept BlockTransferDescriptor4D = requires(T t) { { t.k0 } -> SizeType; { t.m_n } -> SizeType; { t.k1 } -> SizeType; @@ -211,11 +211,12 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; -// Concept to check if a struct specifies convolution input and output block transfer info (Bwd direction). +// Concept to check if a struct specifies convolution input and output block transfer info +// for 4D thread slices. template -concept SpecifiesBlockTransferBwd = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; - { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; +concept SpecifiesBlockTransfer4D = requires(T t) { + { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp index 2340e19e61..6613d2d736 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_diagnostics.hpp @@ -184,9 +184,9 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string { return msg; } -// BlockTransferDescriptorBwd diagnostics (requires k_batch_size) +// BlockTransferDescriptor4D diagnostics (requires k_batch_size) template -consteval auto diagnose_block_transfer_bwd(const char* prefix) -> std::string { +consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string { std::string msg; if constexpr (requires(BT t) { t.k0; }) { @@ -500,7 +500,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string { } template -consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string { +consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string { std::string msg; constexpr bool has_transfer = requires { T::transfer; }; @@ -510,16 +510,20 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string { return msg; } - constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; }; + constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; }; msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n"; if constexpr (!has_a) { msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n"; + } else { + msg += detail::diagnose_block_transfer_4d("transfer.a.block_transfer"); } - constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; }; + constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; }; msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n"; if constexpr (!has_b) { msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n"; + } else { + msg += detail::diagnose_block_transfer_4d("transfer.b.block_transfer"); } constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index e537b7ba99..bf7f0248fd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -242,7 +242,7 @@ template struct BwdXdlAlgorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesBlockTransfer4D) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) @@ -252,7 +252,7 @@ struct BwdXdlAlgorithm { static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c3 = c_SpecifiesBlockTransfer4D; static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; @@ -269,7 +269,7 @@ struct BwdXdlAlgorithm { "Concepts for BwdXdl Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + @@ -283,7 +283,7 @@ template struct BwdXdlV3Algorithm { CHECK_CONCEPT(T, ConvAlgorithmDescriptor) CHECK_CONCEPT(T, SpecifiesThreadBlock) - CHECK_CONCEPT(T, SpecifiesBlockTransferBwd) + CHECK_CONCEPT(T, SpecifiesBlockTransfer) CHECK_CONCEPT(T, SpecifiesLdsTransfer) CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder) CHECK_CONCEPT(T, SpecifiesSourceAccessOrder) @@ -293,7 +293,7 @@ struct BwdXdlV3Algorithm { static constexpr bool c1 = c_ConvAlgorithmDescriptor; static constexpr bool c2 = c_SpecifiesThreadBlock; - static constexpr bool c3 = c_SpecifiesBlockTransferBwd; + static constexpr bool c3 = c_SpecifiesBlockTransfer; static constexpr bool c4 = c_SpecifiesLdsTransfer; static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder; static constexpr bool c6 = c_SpecifiesSourceAccessOrder; @@ -310,7 +310,7 @@ struct BwdXdlV3Algorithm { "Concepts for BwdXdlV3 Algorithm:\n") + DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) + DIAGNOSTIC_LINE(SpecifiesThreadBlock) + - DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) + + DIAGNOSTIC_LINE(SpecifiesBlockTransfer) + DIAGNOSTIC_LINE(SpecifiesLdsTransfer) + DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) + DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) + diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 25cf773694..69facce41b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -12,9 +12,9 @@ namespace ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. struct BlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 - ck::Array thread_cluster_order = {0, 0, 0}; - ck::Array src_access_order = {0, 0, 0}; + ck::Array thread_cluster_dims{}; // k0, m, k1 + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -22,15 +22,15 @@ struct BlockTransfer bool lds_padding = false; }; +template struct BwdBlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0, 0}; - ck::Array thread_cluster_order = {0, 0, 0, 0}; - ck::Array src_access_order = {0, 0, 0, 0}; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; bool lds_padding = false; }; @@ -55,7 +55,7 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() } template -constexpr BwdBlockTransfer SetBwdConvBlockTransfer() +constexpr auto SetBwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; auto& block_order = TRANSFER.block_transfer_access_order; @@ -68,27 +68,25 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer() if constexpr (array_length == 3) { - return BwdBlockTransfer{ - .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + return BwdBlockTransfer<3>{ + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, .lds_padding = lds_cfg.lds_padding, }; } else if constexpr (array_length == 4) { - return BwdBlockTransfer{ + return BwdBlockTransfer<4>{ .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]}, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, - .is_direct_load = lds_cfg.is_direct_load, .lds_padding = lds_cfg.lds_padding, }; }