diff --git a/experimental/builder/include/ck_tile/builder/reflect/README.md b/experimental/builder/include/ck_tile/builder/reflect/README.md index af028f5d5a..d15912dd28 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/README.md +++ b/experimental/builder/include/ck_tile/builder/reflect/README.md @@ -21,7 +21,7 @@ This template is common for XDL and WMMA, forward and backward weight kernels. ` - **`conv_description.hpp`**: The main entry point. Contains the `ConvDescription` struct and the `describe()` factory function. - **`conv_traits.hpp`**: Home of the `ConvTraits` template, which is the core of the property extraction mechanism. -- **`tree_formatter.hpp`**: A simple utility for generating the indented, tree-like format used in the `detailed()` description. +- **`tree_formatter.hpp`**: A tree-building utility that generates indented, tree-like output for the `detailed()` description. ## Usage diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index c14cfce63c..01069c1140 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -64,175 +64,150 @@ class ConvDescription : public Description /// @return A multi-line tree-formatted description covering signature and algorithm details std::string detailed() const override { - TreeFormatter f; - f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel"); - f.writeLine(1, "Signature"); - f.writeLine(2, "Tensor Type: ", traits_.data_type); - f.writeLine(2, "Input Layout: ", traits_.layout[0]); - f.writeLine(2, "Weight Layout: ", traits_.layout[1]); - f.writeLine(2, "Output Layout: ", traits_.layout[2]); - f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op); - f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op); - f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op); + TreeFormatter root(traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel"); - f.writeLast(1, "Algorithm"); + auto& sig = root.add("Signature"); + sig.add("Tensor Type: ", traits_.data_type); + sig.add("Input Layout: ", traits_.layout[0]); + sig.add("Weight Layout: ", traits_.layout[1]); + sig.add("Output Layout: ", traits_.layout[2]); + sig.add("Input elementwise operation: ", traits_.input_element_op); + sig.add("Weights elementwise operation: ", traits_.weight_element_op); + sig.add("Output elementwise operation: ", traits_.output_element_op); + + auto& algo = root.add("Algorithm"); // Compute Block section - f.writeLine(2, "Thread block size: ", traits_.thread_block_size); - f.writeLine(2, - "Data tile size: ", - traits_.tile_dims.m, - "×", - traits_.tile_dims.n, - "×", - traits_.tile_dims.k); + algo.add("Thread block size: ", traits_.thread_block_size); + algo.add("Data tile size: ", + traits_.tile_dims.m, + "×", + traits_.tile_dims.n, + "×", + traits_.tile_dims.k); if(traits_.gemm_padding) - f.writeLine( - 2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT)); + algo.add("Gemm padding: ", + traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT)); else - f.writeLine(2, "Struct does not contain optional gemm_padding argument"); + algo.add("Struct does not contain optional gemm_padding argument"); if(traits_.do_pad_gemm_m) - f.writeLine(2, "Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false)); + algo.add("Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false)); if(traits_.do_pad_gemm_n) - f.writeLine(2, "Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false)); - f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization); + algo.add("Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false)); + algo.add("Convolution specialization: ", traits_.conv_specialization); // Pipeline section - f.writeLine(2, "Pipeline version: ", traits_.pipeline_version); - f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler); - f.writeLine(2, "Warp Gemm parameters: "); - f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n); - f.writeLast(3, - "Number of warp gemm iterations: ", - traits_.warp_gemm.m_iter, - "×", - traits_.warp_gemm.n_iter); + algo.add("Pipeline version: ", traits_.pipeline_version); + algo.add("Pipeline scheduler: ", traits_.pipeline_scheduler); + auto& warpGemm = algo.add("Warp Gemm parameters: "); + warpGemm.add("subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n); + warpGemm.add("Number of warp gemm iterations: ", + traits_.warp_gemm.m_iter, + "×", + traits_.warp_gemm.n_iter); // Memory Access section - f.writeLine(2, "Memory access:"); + auto& memAccess = algo.add("Memory access:"); - f.writeLine(3, "A Tile transfer: "); - f.writeLine(4, - "Tile dimensions: ", - traits_.a_tile_transfer.tile_dimensions.k0, - "×", - traits_.a_tile_transfer.tile_dimensions.m_or_n, - "×", - traits_.a_tile_transfer.tile_dimensions.k1, - "×"); - f.writeLine( - 4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1); - f.writeLine(4, - "Spatial thread distribution over the data tile: ", - traits_.a_tile_transfer.transfer_params.thread_cluster_order[0], - "×", - traits_.a_tile_transfer.transfer_params.thread_cluster_order[1], - "×", - traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]); - f.writeLine(4, - "The order of accessing data tile axes: ", - traits_.a_tile_transfer.transfer_params.src_access_order[0], - "×", - traits_.a_tile_transfer.transfer_params.src_access_order[1], - "×", - traits_.a_tile_transfer.transfer_params.src_access_order[2]); - f.writeLine(4, - "Vectorized memory access axis index (with contiguous memory): ", - traits_.a_tile_transfer.transfer_params.src_vector_dim); - f.writeLine(4, - "Vector access (GMEM read) instruction size: ", - traits_.a_tile_transfer.transfer_params.src_scalar_per_vector); - f.writeLine(4, - "Vector access (LDS write) instruction size: ", - traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); - f.writeLast(4, - "LDS data layout padding (to prevent bank conflicts): ", - traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + auto& aTile = memAccess.add("A Tile transfer: "); + aTile.add("Tile dimensions: ", + traits_.a_tile_transfer.tile_dimensions.k0, + "×", + traits_.a_tile_transfer.tile_dimensions.m_or_n, + "×", + traits_.a_tile_transfer.tile_dimensions.k1, + "×"); + aTile.add("The innermost K subdimension size: ", + traits_.a_tile_transfer.transfer_params.k1); + aTile.add("Spatial thread distribution over the data tile: ", + traits_.a_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + traits_.a_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]); + aTile.add("The order of accessing data tile axes: ", + traits_.a_tile_transfer.transfer_params.src_access_order[0], + "×", + traits_.a_tile_transfer.transfer_params.src_access_order[1], + "×", + traits_.a_tile_transfer.transfer_params.src_access_order[2]); + aTile.add("Vectorized memory access axis index (with contiguous memory): ", + traits_.a_tile_transfer.transfer_params.src_vector_dim); + aTile.add("Vector access (GMEM read) instruction size: ", + traits_.a_tile_transfer.transfer_params.src_scalar_per_vector); + aTile.add("Vector access (LDS write) instruction size: ", + traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + aTile.add("LDS data layout padding (to prevent bank conflicts): ", + traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); - f.writeLine(3, "B Tile transfer: "); - f.writeLine(4, - "Tile dimensions: ", - traits_.b_tile_transfer.tile_dimensions.k0, - "×", - traits_.b_tile_transfer.tile_dimensions.m_or_n, - "×", - traits_.b_tile_transfer.tile_dimensions.k1, - "×"); - f.writeLine( - 4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1); - f.writeLine(4, - "Spatial thread distribution over the data tile: ", - traits_.b_tile_transfer.transfer_params.thread_cluster_order[0], - "×", - traits_.b_tile_transfer.transfer_params.thread_cluster_order[1], - "×", - traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]); - f.writeLine(4, - "The order of accessing data tile axes: ", - traits_.b_tile_transfer.transfer_params.src_access_order[0], - "×", - traits_.b_tile_transfer.transfer_params.src_access_order[1], - "×", - traits_.b_tile_transfer.transfer_params.src_access_order[2]); - f.writeLine(4, - "Vectorized memory access axis index (with contiguous memory): ", - traits_.b_tile_transfer.transfer_params.src_vector_dim); - f.writeLine(4, - "Vector access (GMEM read) instruction size: ", - traits_.b_tile_transfer.transfer_params.src_scalar_per_vector); - f.writeLine(4, - "Vector access (LDS write) instruction size: ", - traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); - f.writeLast(4, - "LDS data layout padding (to prevent bank conflicts): ", - traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + auto& bTile = memAccess.add("B Tile transfer: "); + bTile.add("Tile dimensions: ", + traits_.b_tile_transfer.tile_dimensions.k0, + "×", + traits_.b_tile_transfer.tile_dimensions.m_or_n, + "×", + traits_.b_tile_transfer.tile_dimensions.k1, + "×"); + bTile.add("The innermost K subdimension size: ", + traits_.b_tile_transfer.transfer_params.k1); + bTile.add("Spatial thread distribution over the data tile: ", + traits_.b_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + traits_.b_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]); + bTile.add("The order of accessing data tile axes: ", + traits_.b_tile_transfer.transfer_params.src_access_order[0], + "×", + traits_.b_tile_transfer.transfer_params.src_access_order[1], + "×", + traits_.b_tile_transfer.transfer_params.src_access_order[2]); + bTile.add("Vectorized memory access axis index (with contiguous memory): ", + traits_.b_tile_transfer.transfer_params.src_vector_dim); + bTile.add("Vector access (GMEM read) instruction size: ", + traits_.b_tile_transfer.transfer_params.src_scalar_per_vector); + bTile.add("Vector access (LDS write) instruction size: ", + traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + bTile.add("LDS data layout padding (to prevent bank conflicts): ", + traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); - f.writeLast(3, "C Tile transfer: "); - f.writeLine(4, - "Data shuffle (number of gemm instructions per iteration): ", - traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, - "×", - traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); - f.writeLine(4, - "Spatial thread distribution used to store data: ", - traits_.c_tile_transfer.thread_cluster_dims[0], - "×", - traits_.c_tile_transfer.thread_cluster_dims[1], - "×", - traits_.c_tile_transfer.thread_cluster_dims[2], - "×", - traits_.c_tile_transfer.thread_cluster_dims[3]); - f.writeLast(4, - "Vector access (GMEM write) instruction size: ", - traits_.c_tile_transfer.scalar_per_vector); + auto& cTile = memAccess.add("C Tile transfer: "); + cTile.add("Data shuffle (number of gemm instructions per iteration): ", + traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + "×", + traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + cTile.add("Spatial thread distribution used to store data: ", + traits_.c_tile_transfer.thread_cluster_dims[0], + "×", + traits_.c_tile_transfer.thread_cluster_dims[1], + "×", + traits_.c_tile_transfer.thread_cluster_dims[2], + "×", + traits_.c_tile_transfer.thread_cluster_dims[3]); + cTile.add("Vector access (GMEM write) instruction size: ", + traits_.c_tile_transfer.scalar_per_vector); if(traits_.num_gemm_k_prefetch_stage) - f.writeLine( - 2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0)); + algo.add("Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0)); else - f.writeLine(2, - "Struct does not contain optional " - "num_gemm_k_prefetch_stage parameter"); + algo.add("Struct does not contain optional " + "num_gemm_k_prefetch_stage parameter"); if(traits_.max_transpose_transfer_src_scalar_per_vector) - f.writeLine(2, - "Max Transpose transfer scr scalar per vector: ", - traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0)); + algo.add("Max Transpose transfer scr scalar per vector: ", + traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0)); else - f.writeLine(2, - "Struct does not contain optional " - "max_transpose_transfer_src_scalar_per_vector parameter"); + algo.add("Struct does not contain optional " + "max_transpose_transfer_src_scalar_per_vector parameter"); if(traits_.max_transpose_transfer_dst_scalar_per_vector) - f.writeLine(2, - "Max Transpose dst scalar per vector: ", - traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0)); + algo.add("Max Transpose dst scalar per vector: ", + traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0)); else - f.writeLine(2, - "Struct does not contain optional " - "max_transpose_transfer_dst_scalar_per_vector parameter"); + algo.add("Struct does not contain optional " + "max_transpose_transfer_dst_scalar_per_vector parameter"); if(traits_.num_groups_to_merge) - f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0)); + algo.add("Num groups to merge: ", traits_.num_groups_to_merge.value_or(0)); else - f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter"); + algo.add("Struct does not contain optional num_groups_to_merge parameter"); - return f.getString(); + return root.getString(); } /// @brief Generate a string representation of the instance diff --git a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp index 84c4ad1ddf..8657e8dd2c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -3,26 +3,27 @@ #pragma once +#include #include #include -#include -#include namespace ck_tile::reflect { -// Helper class for formatting hierarchical tree structures with proper indentation -// and tree-drawing characters (├─, └─, │, etc.) +// Tree-node class for building hierarchical tree structures, then rendering them +// with proper indentation and tree-drawing characters (├─, └─, │, etc.) +// +// Unlike a streaming API, the tree is built first and rendered afterwards, +// so last-child status is determined automatically. // // Example Usage: // -// TreeFormatter f; -// f.writeLine(0, "Root"); -// f.writeLine(1, "Branch 1"); -// f.writeLine(2, "Item 1a"); -// f.writeLast(2, "Item 1b"); -// f.writeLast(1, "Branch 2"); -// f.writeLast(2, "Item 2a"); -// std::cout << f.getString() << "\n"; +// TreeFormatter root("Root"); +// auto& b1 = root.add("Branch 1"); +// b1.add("Item 1a"); +// b1.add("Item 1b"); +// auto& b2 = root.add("Branch 2"); +// b2.add("Item 2a"); +// std::cout << root.getString() << "\n"; // // Generated Output: // @@ -35,74 +36,53 @@ namespace ck_tile::reflect { class TreeFormatter { public: - TreeFormatter() = default; - - // Write a line at the specified indentation level (branch continues after this) + // Construct a node with content built from the given arguments template - void writeLine(int indent_level, Args&&... args) + explicit TreeFormatter(Args&&... args) { - writeLineImpl(indent_level, false, std::forward(args)...); + std::ostringstream oss; + ((oss << std::forward(args)), ...); + content_ = oss.str(); } - // Write the last line at the specified indentation level (branch ends) + // Add a child node, returns a reference to it for further nesting template - void writeLast(int indent_level, Args&&... args) + TreeFormatter& add(Args&&... args) { - writeLineImpl(indent_level, true, std::forward(args)...); + children_.emplace_back(std::forward(args)...); + return children_.back(); } - // Get the formatted string (removes trailing newline if present) + // Render the full tree to a string std::string getString() const { - std::string result = oss_.str(); - if(!result.empty() && result.back() == '\n') + std::ostringstream oss; + oss << content_; + for(size_t i = 0; i < children_.size(); ++i) { - result.pop_back(); + oss << '\n'; + children_[i].renderChild(oss, "", i == children_.size() - 1); } - return result; + return oss.str(); } private: - std::ostringstream oss_; - std::vector is_last_at_level_; // Tracks which levels have ended + std::string content_; + // std::deque preserves references to existing elements on push_back/emplace_back, + // unlike std::vector which may reallocate. This allows add() to safely return + // a reference to the newly added child for further nesting. + std::deque children_; - // Implementation of line writing with tree symbols - template - void writeLineImpl(int indent_level, bool is_last, Args&&... args) + // Recursive render helper + void renderChild(std::ostringstream& oss, const std::string& prefix, bool is_last) const { - // Ensure we have enough tracking space - if(static_cast(indent_level) >= is_last_at_level_.size()) + oss << prefix << (is_last ? "└─ " : "├─ ") << content_; + std::string child_prefix = prefix + (is_last ? " " : "│ "); + for(size_t i = 0; i < children_.size(); ++i) { - is_last_at_level_.resize(indent_level + 1, false); - // Level 0 (root) should always be treated as "last" since it has no tree symbols - if(is_last_at_level_.size() > 0) - { - is_last_at_level_[0] = true; - } + oss << '\n'; + children_[i].renderChild(oss, child_prefix, i == children_.size() - 1); } - - // Draw the tree structure - // Start from level 1 (skip level 0 which is the root with no symbols) - for(int i = 1; i < indent_level; ++i) - { - // For all parent levels, draw vertical line or space based on whether they ended - oss_ << (is_last_at_level_[i] ? " " : "│ "); - } - - // Draw the branch symbol for the current level - if(indent_level > 0) - { - oss_ << (is_last ? "└─ " : "├─ "); - } - - // Write the content using fold expression with direct stream insertion - ((oss_ << std::forward(args)), ...); - - oss_ << '\n'; - - // Update tracking for this level AFTER writing the line - // This ensures future lines at deeper levels know if this level ended - is_last_at_level_[indent_level] = is_last; } };