[rocm-libraries] ROCm/rocm-libraries#5083 (commit d65061b)

[CK_BUILDER] Simplify the TreeFormatter.

My original design streamed each line as it was written, so developers had
to keep track of the indentation depth and remember to use `writeLast` for
the last element at each depth. This was the source of many cosmetic
output errors, and the bookkeeping is likely to get more complicated as we
add optional branches.

We switch to a tree-building interface with a simple `add` method. The
only cost is that we have to defer string building, which is a good
tradeoff for our use case.

Tested with `ninja smoke-builder`.
This commit is contained in:
John Shumway
2026-03-07 18:06:33 +00:00
committed by assistant-librarian[bot]
parent e0d11b969b
commit a7b894544e
3 changed files with 164 additions and 209 deletions

View File

@@ -21,7 +21,7 @@ This template is common for XDL and WMMA, forward and backward weight kernels. `
- **`conv_description.hpp`**: The main entry point. Contains the `ConvDescription` struct and the `describe()` factory function.
- **`conv_traits.hpp`**: Home of the `ConvTraits` template, which is the core of the property extraction mechanism.
- **`tree_formatter.hpp`**: A simple utility for generating the indented, tree-like format used in the `detailed()` description.
- **`tree_formatter.hpp`**: A tree-building utility that generates indented, tree-like output for the `detailed()` description.
## Usage

View File

@@ -64,175 +64,150 @@ class ConvDescription : public Description
/// @return A multi-line tree-formatted description covering signature and algorithm details
std::string detailed() const override
{
TreeFormatter f;
f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", traits_.data_type);
f.writeLine(2, "Input Layout: ", traits_.layout[0]);
f.writeLine(2, "Weight Layout: ", traits_.layout[1]);
f.writeLine(2, "Output Layout: ", traits_.layout[2]);
f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op);
TreeFormatter root(traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel");
f.writeLast(1, "Algorithm");
auto& sig = root.add("Signature");
sig.add("Tensor Type: ", traits_.data_type);
sig.add("Input Layout: ", traits_.layout[0]);
sig.add("Weight Layout: ", traits_.layout[1]);
sig.add("Output Layout: ", traits_.layout[2]);
sig.add("Input elementwise operation: ", traits_.input_element_op);
sig.add("Weights elementwise operation: ", traits_.weight_element_op);
sig.add("Output elementwise operation: ", traits_.output_element_op);
auto& algo = root.add("Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", traits_.thread_block_size);
f.writeLine(2,
"Data tile size: ",
traits_.tile_dims.m,
"×",
traits_.tile_dims.n,
"×",
traits_.tile_dims.k);
algo.add("Thread block size: ", traits_.thread_block_size);
algo.add("Data tile size: ",
traits_.tile_dims.m,
"×",
traits_.tile_dims.n,
"×",
traits_.tile_dims.k);
if(traits_.gemm_padding)
f.writeLine(
2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
algo.add("Gemm padding: ",
traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
else
f.writeLine(2, "Struct does not contain optional gemm_padding argument");
algo.add("Struct does not contain optional gemm_padding argument");
if(traits_.do_pad_gemm_m)
f.writeLine(2, "Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false));
algo.add("Do Padd Gemm M: ", traits_.do_pad_gemm_m.value_or(false));
if(traits_.do_pad_gemm_n)
f.writeLine(2, "Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false));
f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
algo.add("Do Padd Gemm N: ", traits_.do_pad_gemm_n.value_or(false));
algo.add("Convolution specialization: ", traits_.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
traits_.warp_gemm.m_iter,
"×",
traits_.warp_gemm.n_iter);
algo.add("Pipeline version: ", traits_.pipeline_version);
algo.add("Pipeline scheduler: ", traits_.pipeline_scheduler);
auto& warpGemm = algo.add("Warp Gemm parameters: ");
warpGemm.add("subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n);
warpGemm.add("Number of warp gemm iterations: ",
traits_.warp_gemm.m_iter,
"×",
traits_.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
auto& memAccess = algo.add("Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
traits_.a_tile_transfer.tile_dimensions.k0,
"×",
traits_.a_tile_transfer.tile_dimensions.m_or_n,
"×",
traits_.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
traits_.a_tile_transfer.transfer_params.src_access_order[0],
"×",
traits_.a_tile_transfer.transfer_params.src_access_order[1],
"×",
traits_.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
traits_.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
traits_.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
auto& aTile = memAccess.add("A Tile transfer: ");
aTile.add("Tile dimensions: ",
traits_.a_tile_transfer.tile_dimensions.k0,
"×",
traits_.a_tile_transfer.tile_dimensions.m_or_n,
"×",
traits_.a_tile_transfer.tile_dimensions.k1,
"×");
aTile.add("The innermost K subdimension size: ",
traits_.a_tile_transfer.transfer_params.k1);
aTile.add("Spatial thread distribution over the data tile: ",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
aTile.add("The order of accessing data tile axes: ",
traits_.a_tile_transfer.transfer_params.src_access_order[0],
"×",
traits_.a_tile_transfer.transfer_params.src_access_order[1],
"×",
traits_.a_tile_transfer.transfer_params.src_access_order[2]);
aTile.add("Vectorized memory access axis index (with contiguous memory): ",
traits_.a_tile_transfer.transfer_params.src_vector_dim);
aTile.add("Vector access (GMEM read) instruction size: ",
traits_.a_tile_transfer.transfer_params.src_scalar_per_vector);
aTile.add("Vector access (LDS write) instruction size: ",
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
aTile.add("LDS data layout padding (to prevent bank conflicts): ",
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
traits_.b_tile_transfer.tile_dimensions.k0,
"×",
traits_.b_tile_transfer.tile_dimensions.m_or_n,
"×",
traits_.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
traits_.b_tile_transfer.transfer_params.src_access_order[0],
"×",
traits_.b_tile_transfer.transfer_params.src_access_order[1],
"×",
traits_.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
traits_.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
traits_.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
auto& bTile = memAccess.add("B Tile transfer: ");
bTile.add("Tile dimensions: ",
traits_.b_tile_transfer.tile_dimensions.k0,
"×",
traits_.b_tile_transfer.tile_dimensions.m_or_n,
"×",
traits_.b_tile_transfer.tile_dimensions.k1,
"×");
bTile.add("The innermost K subdimension size: ",
traits_.b_tile_transfer.transfer_params.k1);
bTile.add("Spatial thread distribution over the data tile: ",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
bTile.add("The order of accessing data tile axes: ",
traits_.b_tile_transfer.transfer_params.src_access_order[0],
"×",
traits_.b_tile_transfer.transfer_params.src_access_order[1],
"×",
traits_.b_tile_transfer.transfer_params.src_access_order[2]);
bTile.add("Vectorized memory access axis index (with contiguous memory): ",
traits_.b_tile_transfer.transfer_params.src_vector_dim);
bTile.add("Vector access (GMEM read) instruction size: ",
traits_.b_tile_transfer.transfer_params.src_scalar_per_vector);
bTile.add("Vector access (LDS write) instruction size: ",
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
bTile.add("LDS data layout padding (to prevent bank conflicts): ",
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
traits_.c_tile_transfer.thread_cluster_dims[0],
"×",
traits_.c_tile_transfer.thread_cluster_dims[1],
"×",
traits_.c_tile_transfer.thread_cluster_dims[2],
"×",
traits_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
traits_.c_tile_transfer.scalar_per_vector);
auto& cTile = memAccess.add("C Tile transfer: ");
cTile.add("Data shuffle (number of gemm instructions per iteration): ",
traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
cTile.add("Spatial thread distribution used to store data: ",
traits_.c_tile_transfer.thread_cluster_dims[0],
"×",
traits_.c_tile_transfer.thread_cluster_dims[1],
"×",
traits_.c_tile_transfer.thread_cluster_dims[2],
"×",
traits_.c_tile_transfer.thread_cluster_dims[3]);
cTile.add("Vector access (GMEM write) instruction size: ",
traits_.c_tile_transfer.scalar_per_vector);
if(traits_.num_gemm_k_prefetch_stage)
f.writeLine(
2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0));
algo.add("Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"num_gemm_k_prefetch_stage parameter");
algo.add("Struct does not contain optional "
"num_gemm_k_prefetch_stage parameter");
if(traits_.max_transpose_transfer_src_scalar_per_vector)
f.writeLine(2,
"Max Transpose transfer scr scalar per vector: ",
traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
algo.add("Max Transpose transfer scr scalar per vector: ",
traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
algo.add("Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
if(traits_.max_transpose_transfer_dst_scalar_per_vector)
f.writeLine(2,
"Max Transpose dst scalar per vector: ",
traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0));
algo.add("Max Transpose dst scalar per vector: ",
traits_.max_transpose_transfer_dst_scalar_per_vector.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_dst_scalar_per_vector parameter");
algo.add("Struct does not contain optional "
"max_transpose_transfer_dst_scalar_per_vector parameter");
if(traits_.num_groups_to_merge)
f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
algo.add("Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
else
f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter");
algo.add("Struct does not contain optional num_groups_to_merge parameter");
return f.getString();
return root.getString();
}
/// @brief Generate a string representation of the instance

View File

@@ -3,26 +3,27 @@
#pragma once
#include <deque>
#include <list>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
namespace ck_tile::reflect {
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
// Tree-node class for building hierarchical tree structures, then rendering them
// with proper indentation and tree-drawing characters (├─, └─, │, etc.)
//
// Unlike a streaming API, the tree is built first and rendered afterwards,
// so last-child status is determined automatically.
//
// Example Usage:
//
// TreeFormatter f;
// f.writeLine(0, "Root");
// f.writeLine(1, "Branch 1");
// f.writeLine(2, "Item 1a");
// f.writeLast(2, "Item 1b");
// f.writeLast(1, "Branch 2");
// f.writeLast(2, "Item 2a");
// std::cout << f.getString() << "\n";
// TreeFormatter root("Root");
// auto& b1 = root.add("Branch 1");
// b1.add("Item 1a");
// b1.add("Item 1b");
// auto& b2 = root.add("Branch 2");
// b2.add("Item 2a");
// std::cout << root.getString() << "\n";
//
// Generated Output:
//
@@ -35,74 +36,53 @@ namespace ck_tile::reflect {
class TreeFormatter
{
public:
    /// @brief Construct a node whose text is the concatenation of all arguments.
    ///
    /// Each argument is inserted, in order, into a string stream with operator<<.
    /// With no arguments the node text is empty.
    ///
    /// The overload is disabled when the only argument is itself a TreeFormatter, so
    /// copying or moving a node (including from a non-const lvalue) selects the implicit
    /// copy/move constructor instead of trying to stream a TreeFormatter.
    template <typename... Args,
              typename = std::enable_if_t<
                  !(sizeof...(Args) == 1 &&
                    (std::is_same_v<std::decay_t<Args>, TreeFormatter> && ...))>>
    explicit TreeFormatter(Args&&... args)
    {
        std::ostringstream text;
        ((text << std::forward<Args>(args)), ...);
        content_ = text.str();
    }

    /// @brief Append a child node built from the given arguments.
    /// @return Reference to the new child, usable for further nesting. The reference stays
    ///         valid while more children are added to this node or to any other node.
    template <typename... Args>
    TreeFormatter& add(Args&&... args)
    {
        return children_.emplace_back(std::forward<Args>(args)...);
    }

    /// @brief Render this node and all of its descendants.
    /// @return The node text on the first line, followed by one line per descendant drawn
    ///         with tree symbols. Lines are separated by '\n'; there is no trailing newline.
    std::string getString() const
    {
        std::ostringstream out;
        out << content_;
        renderChildren(out, "");
        return out.str();
    }

private:
    std::string content_;

    // std::list keeps references to existing elements valid when new elements are appended,
    // which is what lets add() hand out a reference to the new child. It is also one of the
    // containers that may be instantiated with an incomplete element type (C++17), which is
    // required here because the element type is the enclosing class itself.
    std::list<TreeFormatter> children_;

    // Writes every child (and, recursively, its descendants) on its own line.
    // `prefix` is the text already accumulated for the ancestors of these children.
    void renderChildren(std::ostringstream& out, const std::string& prefix) const
    {
        for(const TreeFormatter& child : children_)
        {
            const bool is_last = (&child == &children_.back());
            out << '\n' << prefix << (is_last ? "└─ " : "├─ ") << child.content_;
            // Below a child that still has siblings after it, keep drawing the vertical
            // bar; below the last child, pad with spaces. Both are three columns wide so
            // deeper levels line up under the branch markers.
            child.renderChildren(out, prefix + (is_last ? "   " : "│  "));
        }
    }
};