Extract TreeFormatter from Description class.

Description class was cluttered with hard-coded formatting, so we remove and generalize the formatting, simplifying the Description::detailed() method.
This commit is contained in:
John Shumway
2025-10-07 15:43:13 +00:00
parent f8838d7b38
commit fe5fbcbc64
2 changed files with 201 additions and 69 deletions

View File

@@ -6,6 +6,7 @@
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/conv_traits.hpp>
#include <ck_tile/builder/tree_formatter.hpp>
namespace ck_tile::reflect {
@@ -53,49 +54,6 @@ struct CBlockTransferInfo
int n_wave_per_xdl;
};
// Opaque implmentation detail for CK convolution pipelines.
//
// This notation is opaque and should be replace by more general, descriptive data.
// TODO: Remove this implementation detail from ck_tile::reflect.
enum class PipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
// Convert enums to string.
// TODO: Remove this once we hide the pipeline version from reflection.
constexpr std::string_view PipelineToString(PipelineVersion pipeline)
{
switch(pipeline)
{
case PipelineVersion::V1: return "V1";
case PipelineVersion::V2: return "V2";
case PipelineVersion::V3: return "V3";
case PipelineVersion::V4: return "V4";
case PipelineVersion::V5: return "V5";
default: return "Unknown";
}
}
// Convert CK pipeline version to reflect pipeline version
// TODO: Remove this once we hide the pipeline version from reflection.
constexpr PipelineVersion ConvertPipelineVersion(ck::BlockGemmPipelineVersion ck_version)
{
switch(ck_version)
{
case ck::BlockGemmPipelineVersion::v1: return PipelineVersion::V1;
case ck::BlockGemmPipelineVersion::v2: return PipelineVersion::V2;
case ck::BlockGemmPipelineVersion::v3: return PipelineVersion::V3;
case ck::BlockGemmPipelineVersion::v4: return PipelineVersion::V4;
case ck::BlockGemmPipelineVersion::v5: return PipelineVersion::V5;
}
return PipelineVersion::V3; // Fallback
}
// Algorithm information - groups all algorithm-related configuration
struct AlgorithmInfo
{
@@ -104,7 +62,7 @@ struct AlgorithmInfo
BlockTransferInfo a_transfer;
BlockTransferInfo b_transfer;
CBlockTransferInfo c_transfer;
PipelineVersion pipeline;
ck::BlockGemmPipelineVersion pipeline;
};
// Provides human-readable descriptions of ConvBuilder configurations.
@@ -126,41 +84,77 @@ struct Description
std::string detailed() const
{
std::ostringstream oss;
TreeFormatter tree;
// Root line - no tree formatting
oss << signature.spatial_dim << "D " << builder::ConvDirectionToString(signature.direction)
<< " Convolution Kernel\n";
oss << "├─ Signature\n";
oss << "│ ├─ Tensor Type: " << builder::DataTypeToString(signature.data_type) << "\n";
oss << "│ └─ Memory Layout: " << builder::LayoutToString(signature.layout) << "\n";
oss << "└─ Algorithm\n";
tree.writeLine(1, "Signature");
tree.writeLine(2, "Tensor Type: ", signature.data_type);
tree.writeLastLine(2, "Memory Layout: ", signature.layout);
tree.writeLastLine(1, "Algorithm");
// Compute Block section
oss << " ├─ Compute Block: " << algorithm.block.m << "×" << algorithm.block.n << "×"
<< algorithm.block.k << " submatrix (" << algorithm.block.block_size << " threads)\n";
tree.writeLine(2,
"Compute Block: ",
algorithm.block.m,
"×",
algorithm.block.n,
"×",
algorithm.block.k,
" submatrix (",
algorithm.block.block_size,
" threads)");
oss << " │ ├─ XDL Waves: " << algorithm.tuning.m_xdl_per_wave << "×"
<< algorithm.tuning.n_xdl_per_wave << " mapping ("
<< (algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave)
<< " waves total)\n";
oss << " └─ Tuning: ak1=" << algorithm.tuning.ak1 << ", bk1=" << algorithm.tuning.bk1
<< " (optimized for MI300 MFMA)\n";
tree.writeLine(3,
"XDL Waves: ",
algorithm.tuning.m_xdl_per_wave,
"×",
algorithm.tuning.n_xdl_per_wave,
" mapping (",
(algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave),
" waves total)");
tree.writeLastLine(3,
"Tuning: ak1=",
algorithm.tuning.ak1,
", bk1=",
algorithm.tuning.bk1,
" (optimized for MI300 MFMA)");
// Memory Layout section
oss << " ├─ Memory Layout:\n";
oss << " │ ├─ A-Transfer: " << algorithm.a_transfer.k0 << "×"
<< algorithm.a_transfer.m_or_n << "×" << algorithm.a_transfer.k1
<< " thread clusters (coalesced reads)\n";
oss << " ├─ B-Transfer: " << algorithm.b_transfer.k0 << "×"
<< algorithm.b_transfer.m_or_n << "×" << algorithm.b_transfer.k1
<< " thread clusters (broadcast-friendly)\n";
oss << " └─ C-Transfer: " << algorithm.c_transfer.m_block << "×"
<< algorithm.c_transfer.m_wave_per_xdl << "×" << algorithm.c_transfer.n_block << "×"
<< algorithm.c_transfer.n_wave_per_xdl << " clusters (efficient writeback)\n";
tree.writeLine(2, "Memory Layout:");
tree.writeLine(3,
"A-Transfer: ",
algorithm.a_transfer.k0,
"×",
algorithm.a_transfer.m_or_n,
"×",
algorithm.a_transfer.k1,
" thread clusters (coalesced reads)");
tree.writeLine(3,
"B-Transfer: ",
algorithm.b_transfer.k0,
"×",
algorithm.b_transfer.m_or_n,
"×",
algorithm.b_transfer.k1,
" thread clusters (broadcast-friendly)");
tree.writeLastLine(3,
"C-Transfer: ",
algorithm.c_transfer.m_block,
"×",
algorithm.c_transfer.m_wave_per_xdl,
"×",
algorithm.c_transfer.n_block,
"×",
algorithm.c_transfer.n_wave_per_xdl,
" clusters (efficient writeback)");
// Pipeline section
oss << " └─ Pipeline: " << PipelineToString(algorithm.pipeline);
tree.writeLastLine(2, "Pipeline: ", algorithm.pipeline);
return oss.str();
return oss.str() + tree.getString();
}
// Educational explanation of optimization choices
@@ -218,7 +212,7 @@ Description Describe()
Traits::c_block_transfer.thread_cluster_dims[2],
Traits::c_block_transfer.thread_cluster_dims[3]},
// pipeline
ConvertPipelineVersion(Traits::pipeline_version)}};
Traits::pipeline_version}};
}
} // namespace ck_tile::reflect

View File

@@ -0,0 +1,138 @@
#pragma once
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/conv_traits.hpp>
namespace ck_tile::reflect {
// Convert CK block GEMM pipeline enums to string.
// TODO: Remove this once we hide the pipeline version from reflection.
constexpr std::string_view PipelineToString(ck::BlockGemmPipelineVersion pipeline)
{
switch(pipeline)
{
case ck::BlockGemmPipelineVersion::v1: return "V1";
case ck::BlockGemmPipelineVersion::v2: return "V2";
case ck::BlockGemmPipelineVersion::v3: return "V3";
case ck::BlockGemmPipelineVersion::v4: return "V4";
case ck::BlockGemmPipelineVersion::v5: return "V5";
default: return "Unknown";
}
}
// enum class PipelineVersion;
// // Forward declare PipelineToString (actual definition must be included separately)
// constexpr std::string_view PipelineToString(PipelineVersion);
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
class TreeFormatter
{
public:
TreeFormatter() = default;
// Write a line at the specified indentation level (branch continues after this)
template <typename... Args>
void writeLine(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
}
// Write the last line at the specified indentation level (branch ends)
template <typename... Args>
void writeLastLine(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
}
// Get the formatted string (removes trailing newline if present)
std::string getString() const
{
std::string result = oss_.str();
if(!result.empty() && result.back() == '\n')
{
result.pop_back();
}
return result;
}
private:
std::ostringstream oss_;
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
// Helper to format individual arguments with automatic type conversion
template <typename T>
void formatArg(const T& arg)
{
if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
builder::DataType>)
{
oss_ << builder::DataTypeToString(arg);
}
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
builder::ConvDirection>)
{
oss_ << builder::ConvDirectionToString(arg);
}
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
builder::GroupConvLayout>)
{
oss_ << builder::LayoutToString(arg);
}
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
ck::BlockGemmPipelineVersion>)
{
oss_ << PipelineToString(arg);
}
else
{
oss_ << arg; // Default: just stream it
}
}
// Implementation of line writing with tree symbols
template <typename... Args>
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
{
// Ensure we have enough tracking space
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
{
is_last_at_level_.resize(indent_level + 1, false);
// Level 0 (root) should always be treated as "last" since it has no tree symbols
if(is_last_at_level_.size() > 0)
{
is_last_at_level_[0] = true;
}
}
// Draw the tree structure
// Start from level 1 (skip level 0 which is the root with no symbols)
for(int i = 1; i < indent_level; ++i)
{
// For all parent levels, draw vertical line or space based on whether they ended
oss_ << (is_last_at_level_[i] ? " " : "");
}
// Draw the branch symbol for the current level
if(indent_level > 0)
{
oss_ << (is_last ? "└─ " : "├─ ");
}
// Write the content using fold expression with formatArg
(formatArg(std::forward<Args>(args)), ...);
oss_ << '\n';
// Update tracking for this level AFTER writing the line
// This ensures future lines at deeper levels know if this level ended
is_last_at_level_[indent_level] = is_last;
}
};
} // namespace ck_tile::reflect