diff --git a/experimental/builder/include/ck_tile/builder/builder_utils.hpp b/experimental/builder/include/ck_tile/builder/builder_utils.hpp index 5b4981c630..f16d96bec6 100644 --- a/experimental/builder/include/ck_tile/builder/builder_utils.hpp +++ b/experimental/builder/include/ck_tile/builder/builder_utils.hpp @@ -78,66 +78,4 @@ struct UnsupportedEnumValue { }; -// Helper functions to convert enums to strings -constexpr std::string_view ConvDirectionToString(ConvDirection dir) -{ - switch(dir) - { - case ConvDirection::FORWARD: return "Forward"; - case ConvDirection::BACKWARD_DATA: return "Backward Data"; - case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; - default: return "Unknown"; - } -} - -constexpr std::string_view DataTypeToString(DataType dt) -{ - switch(dt) - { - case DataType::FP16: return "FP16"; - case DataType::FP32: return "FP32"; - case DataType::BF16: return "BF16"; - case DataType::FP8: return "FP8"; - case DataType::I8: return "I8"; - case DataType::U8: return "U8"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout1D layout) -{ - switch(layout) - { - case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK"; - case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK"; - case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW"; - case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout2D layout) -{ - switch(layout) - { - case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK"; - case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK"; - case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW"; - case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout3D layout) -{ - switch(layout) - { - case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK"; - case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK"; - case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW"; - case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW"; - default: return "Unknown"; - } -} - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp index f016a342d3..3869c7b538 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); // Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); // Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); @@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward = // Predicate for DeviceGroupedConvBwdWeight operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); // Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); // Predicate for DeviceGroupedConvBwdWeightMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); // Predicate for DeviceGroupedConvBwdWeight_Dl operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); @@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight = // Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); // Predicate for DeviceGroupedConvBwdDataMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); // Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp new file mode 100644 index 0000000000..0b58f5a3b7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +/// @file conv_description.hpp +/// @brief Provides human-readable descriptions of ConvBuilder configurations + +namespace ck_tile::reflect::conv { + +struct ConvSignatureInfo +{ + int spatial_dim; + builder::ConvDirection direction; + std::variant + layout; + builder::DataType data_type; + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; +}; + +// Algorithm information - groups all algorithm-related configuration +struct GemmAlgorithmInfo +{ + int thread_block_size; + DataTileInfo tile_dims; + WarpGemmParams warp_gemm; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; + OutputTileTransferInfo c_tile_transfer; + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; + std::variant + conv_specialization; + builder::GemmPadding padding; +}; + +// Provides human-readable descriptions of ConvBuilder configurations. +struct ConvDescription +{ + ConvSignatureInfo signature; + GemmAlgorithmInfo algorithm; + + // Brief one-line summary + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " << signature.direction << " convolution"; + return oss.str(); + } + + // Detailed hierarchical description + std::string detailed() const + { + TreeFormatter f; + f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel"); + f.writeLine(1, "Signature"); + f.writeLine(2, "Tensor Type: ", signature.data_type); + f.writeLine(2, "Memory Layout: ", signature.layout); + f.writeLine(2, "Input elementwise operation: ", signature.input_element_op); + f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); + f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); + + f.writeLine(1, "Algorithm"); + // Compute Block section + f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); + f.writeLine(2, + "Data tile size: ", + algorithm.tile_dims.m, + "×", + algorithm.tile_dims.n, + "×", + algorithm.tile_dims.k); + f.writeLine(2, "Gemm padding: ", algorithm.padding); + f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization); + // Pipeline section + f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version); + f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler); + f.writeLine(2, "Warp Gemm parameters: "); + f.writeLine( + 3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n); + f.writeLast(3, + "Number of warp gemm iterations: ", + algorithm.warp_gemm.m_iter, + "×", + algorithm.warp_gemm.n_iter); + + // Memory Access section + f.writeLine(2, "Memory access:"); + + f.writeLine(3, "A Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.a_tile_transfer.tile_dimensions.k0, + "×", + algorithm.a_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.a_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.a_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.a_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLine(3, "B Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.b_tile_transfer.tile_dimensions.k0, + "×", + algorithm.b_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.b_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.b_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.b_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLast(3, "C Tile transfer: "); + f.writeLine(4, + "Data shuffle (number of gemm instructions per iteration): ", + algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + "×", + algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + f.writeLine(4, + "Spatial thread distribution used to store data: ", + algorithm.c_tile_transfer.thread_cluster_dims[0], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[1], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[2], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[3]); + f.writeLast(4, + "Vector access (GMEM write) instruction size: ", + algorithm.c_tile_transfer.scalar_per_vector); + f.writeLast(2); + f.writeLast(1); + return f.getString(); + } + + // Educational explanation of optimization choices + std::string explain() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } + + // Performance characteristics and use case guidance + std::string suggest() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } +}; + +// Helper concept to detect if a type has InstanceTraits specialization +template +concept HasInstanceTraits = requires { typename InstanceTraits; }; + +// Helper concept to detect ConvBuilder types +template +concept IsConvBuilder = requires { + typename T::Factory; + typename T::Instance; +}; + +// Primary factory function: Create ConvDescription from Instance type directly +template + requires HasInstanceTraits +ConvDescription Describe() +{ + using Traits = ConvTraits; + + return ConvDescription{ + .signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .layout = Traits::layout, + .data_type = Traits::data_type, + .input_element_op = Traits::input_element_op, + .weight_element_op = Traits::weight_element_op, + .output_element_op = Traits::output_element_op}, + .algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size, + .tile_dims = Traits::tile_dims, + .warp_gemm = Traits::warp_gemm, + .a_tile_transfer = Traits::a_tile_transfer, + .b_tile_transfer = Traits::b_tile_transfer, + .c_tile_transfer = Traits::c_tile_transfer, + .pipeline_version = Traits::pipeline_version, + .pipeline_scheduler = Traits::pipeline_scheduler, + .conv_specialization = Traits::conv_specialization, + .padding = Traits::gemm_padding}}; +} + +// Backward compatibility: Create ConvDescription from Builder type +template + requires IsConvBuilder && (!HasInstanceTraits) +ConvDescription Describe() +{ + // Delegate to Instance-based version + using Instance = typename Builder::Instance; + return Describe(); +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a74d77d155..86cf11f647 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index c863d2306c..e4d154ae10 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -13,7 +13,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include diff --git a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp new file mode 100644 index 0000000000..6a80a994ee --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +namespace ck_tile::reflect { + +// Helper class for formatting hierarchical tree structures with proper indentation +// and tree-drawing characters (├─, └─, │, etc.) +// +// Example Usage: +// +// TreeFormatter f; +// f.writeLine(0, "Root"); +// f.writeLine(1, "Branch 1"); +// f.writeLine(2, "Item 1a"); +// f.writeLast(2, "Item 1b"); +// f.writeLast(1, "Branch 2"); +// f.writeLast(2, "Item 2a"); +// std::cout << f.getString() << "\n"; +// +// Generated Output: +// +// Root +// ├─ Branch 1 +// │ ├─ Item 1a +// │ └─ Item 1b +// └─ Branch 2 +// └─ Item 2a +class TreeFormatter +{ + public: + TreeFormatter() = default; + + // Write a line at the specified indentation level (branch continues after this) + template + void writeLine(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, false, std::forward(args)...); + } + + // Write the last line at the specified indentation level (branch ends) + template + void writeLast(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, true, std::forward(args)...); + } + + // Get the formatted string (removes trailing newline if present) + std::string getString() const + { + std::string result = oss_.str(); + if(!result.empty() && result.back() == '\n') + { + result.pop_back(); + } + return result; + } + + private: + std::ostringstream oss_; + std::vector is_last_at_level_; // Tracks which levels have ended + + // Implementation of line writing with tree symbols + template + void writeLineImpl(int indent_level, bool is_last, Args&&... args) + { + // Ensure we have enough tracking space + if(static_cast(indent_level) >= is_last_at_level_.size()) + { + is_last_at_level_.resize(indent_level + 1, false); + // Level 0 (root) should always be treated as "last" since it has no tree symbols + if(is_last_at_level_.size() > 0) + { + is_last_at_level_[0] = true; + } + } + + // Draw the tree structure + // Start from level 1 (skip level 0 which is the root with no symbols) + for(int i = 1; i < indent_level; ++i) + { + // For all parent levels, draw vertical line or space based on whether they ended + oss_ << (is_last_at_level_[i] ? " " : "│ "); + } + + // Draw the branch symbol for the current level + if(indent_level > 0) + { + oss_ << (is_last ? "└─ " : "├─ "); + } + + // Write the content using fold expression with direct stream insertion + ((oss_ << std::forward(args)), ...); + + oss_ << '\n'; + + // Update tracking for this level AFTER writing the line + // This ensures future lines at deeper levels know if this level ended + is_last_at_level_[indent_level] = is_last; + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2af10346e5..a58c994288 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + namespace ck_tile::builder { enum class DataType @@ -215,4 +219,275 @@ enum class PipelineScheduler INTERWAVE }; +// ostream operator overloads for enum classes +inline std::ostream& operator<<(std::ostream& os, DataType dt) +{ + using enum DataType; + switch(dt) + { + case FP16: return os << "FP16"; + case FP32: return os << "FP32"; + case BF16: return os << "BF16"; + case FP8: return os << "FP8"; + case I8: return os << "I8"; + case U8: return os << "U8"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + using enum ConvDirection; + switch(dir) + { + case FORWARD: return os << "Forward"; + case BACKWARD_DATA: return os << "Backward Data"; + case BACKWARD_WEIGHT: return os << "Backward Weight"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout) +{ + using enum GroupConvLayout1D; + switch(layout) + { + case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK"; + case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK"; + case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW"; + case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout) +{ + using enum GroupConvLayout2D; + switch(layout) + { + case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK"; + case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK"; + case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW"; + case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) +{ + using enum GroupConvLayout3D; + switch(layout) + { + case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK"; + case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK"; + case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW"; + case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) +{ + using enum FwdGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: + return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; + case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: + return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) +{ + using enum BwdDataGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; + case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: + return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) +{ + using enum BwdWeightGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; + case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + case DeviceGroupedConvBwdWeight_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; + case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) +{ + using enum ElementwiseOperation; + switch(op) + { + case BIAS: return os << "BIAS"; + case BIAS_CLAMP: return os << "BIAS_CLAMP"; + case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; + case BILINEAR: return os << "BILINEAR"; + case CLAMP: return os << "CLAMP"; + case SCALE: return os << "SCALE"; + case PASS_THROUGH: return os << "PASS_THROUGH"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) +{ + using enum PipelineVersion; + switch(ver) + { + case V1: return os << "V1"; + case V2: return os << "V2"; + case V3: return os << "V3"; + case V4: return os << "V4"; + case V5: return os << "V5"; + case WEIGHT_ONLY: return os << "WEIGHT_ONLY"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) +{ + using enum GemmSpecialization; + switch(spec) + { + case Default: return os << "Default"; + case MPadding: return os << "MPadding"; + case NPadding: return os << "NPadding"; + case KPadding: return os << "KPadding"; + case MNPadding: return os << "MNPadding"; + case MKPadding: return os << "MKPadding"; + case NKPadding: return os << "NKPadding"; + case MNKPadding: return os << "MNKPadding"; + case OPadding: return os << "OPadding"; + case MOPadding: return os << "MOPadding"; + case NOPadding: return os << "NOPadding"; + case KOPadding: return os << "KOPadding"; + case MNOPadding: return os << "MNOPadding"; + case MKOPadding: return os << "MKOPadding"; + case NKOPadding: return os << "NKOPadding"; + case MNKOPadding: return os << "MNKOPadding"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +{ + using enum ConvFwdSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_3x3: return os << "FILTER_3x3"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) +{ + using enum ConvBwdDataSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) +{ + using enum ConvBwdWeightSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case ODD_C: return os << "ODD_C"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) +{ + using enum GemmPadding; + switch(padding) + { + case DEFAULT: return os << "DEFAULT"; + case M_PADDING: return os << "M_PADDING"; + case N_PADDING: return os << "N_PADDING"; + case K_PADDING: return os << "K_PADDING"; + case MN_PADDING: return os << "MN_PADDING"; + case MK_PADDING: return os << "MK_PADDING"; + case NK_PADDING: return os << "NK_PADDING"; + case MNK_PADDING: return os << "MNK_PADDING"; + case O_PADDING: return os << "O_PADDING"; + case MO_PADDING: return os << "MO_PADDING"; + case NO_PADDING: return os << "NO_PADDING"; + case KO_PADDING: return os << "KO_PADDING"; + case MNO_PADDING: return os << "MNO_PADDING"; + case MKO_PADDING: return os << "MKO_PADDING"; + case NKO_PADDING: return os << "NKO_PADDING"; + case MNKO_PADDING: return os << "MNKO_PADDING"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) +{ + using enum PipelineScheduler; + switch(sched) + { + case DEFAULT: return os << "DEFAULT"; + case INTRAWAVE: return os << "INTRAWAVE"; + case INTERWAVE: return os << "INTERWAVE"; + default: return os << "Unknown"; + } +} + +// ostream operator overload for std::variant of layout types +inline std::ostream& +operator<<(std::ostream& os, + const std::variant& layout) +{ + std::visit([&os](const auto& l) { os << l; }, layout); + return os; +} + +// ostream operator overload for std::variant of convolution specializations +inline std::ostream& operator<<(std::ostream& os, + const std::variant& spec) +{ + std::visit([&os](const auto& s) { os << s; }, spec); + return os; +} + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 0cb3237f8c..b776edbcde 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -67,6 +67,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test add_ck_builder_test(test_conv_traits conv/test_conv_traits.cpp) +add_ck_builder_test(test_conv_description + test_conv_description.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp new file mode 100644 index 0000000000..97af4af795 --- /dev/null +++ b/experimental/builder/test/test_conv_description.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include "testing_utils.hpp" +#include "impl/conv_signature_types.hpp" +#include "impl/conv_algorithm_types.hpp" + +namespace { + +namespace ckb = ck_tile::builder; +namespace ckr = ck_tile::reflect::conv; +namespace ckt = ck_tile::test; + +// Defines the signature of the convolution operation to be tested. +// This includes dimensionality, direction, data layout, and data type. +struct ConvSignature +{ + int spatial_dim = 2; + ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; + ckb::GroupConvDeviceOp device_operation = + ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; +}; +static_assert(ckb::ConvSignatureDescriptor); + +struct DefaultAlgorithm +{ + ckb::test::ThreadBlock thread_block{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + + ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 16, + .n_per_xdl = 16, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4}; + + ckb::test::BlockTransferABC block_transfer{ + .block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8}, + .block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {.order = {0, 1, 2}}, + .block_transfer_access_order_b = {.order = {0, 1, 2}}, + .src_access_order_a = {.order = {0, 1, 2}}, + .src_access_order_b = {.order = {0, 1, 2}}}; + + ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = ckb::PipelineScheduler::INTRAWAVE}; +}; +static_assert(ckb::ConvAlgorithmDescriptor); + +TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); +} + +TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().detailed(), + ckt::StringEqWithDiff( // + "2D Forward Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "├─ Algorithm\n" + "│ ├─ Thread block size: 256\n" + "│ ├─ Data tile size: 256×256×32\n" + "│ ├─ Gemm padding: DEFAULT\n" + "│ ├─ Convolution specialization: DEFAULT\n" + "│ ├─ Pipeline version: V4\n" + "│ ├─ Pipeline scheduler: INTRAWAVE\n" + "│ ├─ Warp Gemm parameters: \n" + "│ │ ├─ subtile size: 16×16\n" + "│ │ └─ Number of warp gemm iterations: 4×4\n" + "│ ├─ Memory access:\n" + "│ │ ├─ A Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ ├─ B Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ └─ C Tile transfer: \n" + "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + "│ │ └─ Vector access (GMEM write) instruction size: 8\n" + "│ └─ \n" + "└─ ")); +} + +// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory +// does not have a specialization for backward data convolutions. The test fails with: +// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'" +// +// To enable this test, a ConvFactory specialization for backward data operations must be +// implemented first. +// +// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription) +// { +// struct BackwardDataSignature +// { +// int spatial_dim = 2; +// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA; +// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; +// ckb::DataType data_type = ckb::DataType::FP16; +// ckb::ElementwiseOperation elementwise_operation = +// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation = +// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; +// }; +// static_assert(ckb::ConvSignatureDescriptor); +// +// static constexpr const BackwardDataSignature SIGNATURE; +// static constexpr const DefaultAlgorithm ALGORITHM; +// using Builder = ckb::ConvBuilder; +// +// // Verify Brief works +// EXPECT_THAT(ckr::Describe().brief(), +// ckt::StringEqWithDiff("2D Backward Data convolution")); +// +// // Verify detailed works - to be updated once ConvFactory is implemented +// EXPECT_THAT(ckr::Describe().detailed(), +// ckt::StringEqWithDiff("PLACEHOLDER")); +// } +} // namespace