mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Extract TreeFormatter from Description class.
Description class was cluttered with hard-coded formatting, so we remove and generalize the formatting, simplifying the Description::detailed() method.
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/conv_traits.hpp>
|
||||
#include <ck_tile/builder/tree_formatter.hpp>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
@@ -53,49 +54,6 @@ struct CBlockTransferInfo
|
||||
int n_wave_per_xdl;
|
||||
};
|
||||
|
||||
// Opaque implementation detail for CK convolution pipelines.
|
||||
//
|
||||
// This notation is opaque and should be replaced by more general, descriptive data.
|
||||
// TODO: Remove this implementation detail from ck_tile::reflect.
|
||||
// Pipeline version tag exposed through ck_tile::reflect.
// Five enumerators, V1 through V5, declared in ascending order with implicit
// underlying values (V1 is 0, V5 is 4).
enum class PipelineVersion { V1, V2, V3, V4, V5 };
|
||||
|
||||
// Convert enums to string.
|
||||
// TODO: Remove this once we hide the pipeline version from reflection.
|
||||
constexpr std::string_view PipelineToString(PipelineVersion pipeline)
|
||||
{
|
||||
switch(pipeline)
|
||||
{
|
||||
case PipelineVersion::V1: return "V1";
|
||||
case PipelineVersion::V2: return "V2";
|
||||
case PipelineVersion::V3: return "V3";
|
||||
case PipelineVersion::V4: return "V4";
|
||||
case PipelineVersion::V5: return "V5";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert CK pipeline version to reflect pipeline version
|
||||
// TODO: Remove this once we hide the pipeline version from reflection.
|
||||
constexpr PipelineVersion ConvertPipelineVersion(ck::BlockGemmPipelineVersion ck_version)
|
||||
{
|
||||
switch(ck_version)
|
||||
{
|
||||
case ck::BlockGemmPipelineVersion::v1: return PipelineVersion::V1;
|
||||
case ck::BlockGemmPipelineVersion::v2: return PipelineVersion::V2;
|
||||
case ck::BlockGemmPipelineVersion::v3: return PipelineVersion::V3;
|
||||
case ck::BlockGemmPipelineVersion::v4: return PipelineVersion::V4;
|
||||
case ck::BlockGemmPipelineVersion::v5: return PipelineVersion::V5;
|
||||
}
|
||||
return PipelineVersion::V3; // Fallback
|
||||
}
|
||||
|
||||
// Algorithm information - groups all algorithm-related configuration
|
||||
struct AlgorithmInfo
|
||||
{
|
||||
@@ -104,7 +62,7 @@ struct AlgorithmInfo
|
||||
BlockTransferInfo a_transfer;
|
||||
BlockTransferInfo b_transfer;
|
||||
CBlockTransferInfo c_transfer;
|
||||
PipelineVersion pipeline;
|
||||
ck::BlockGemmPipelineVersion pipeline;
|
||||
};
|
||||
|
||||
// Provides human-readable descriptions of ConvBuilder configurations.
|
||||
@@ -126,41 +84,77 @@ struct Description
|
||||
std::string detailed() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
TreeFormatter tree;
|
||||
|
||||
// Root line - no tree formatting
|
||||
oss << signature.spatial_dim << "D " << builder::ConvDirectionToString(signature.direction)
|
||||
<< " Convolution Kernel\n";
|
||||
oss << "├─ Signature\n";
|
||||
oss << "│ ├─ Tensor Type: " << builder::DataTypeToString(signature.data_type) << "\n";
|
||||
oss << "│ └─ Memory Layout: " << builder::LayoutToString(signature.layout) << "\n";
|
||||
|
||||
oss << "└─ Algorithm\n";
|
||||
tree.writeLine(1, "Signature");
|
||||
tree.writeLine(2, "Tensor Type: ", signature.data_type);
|
||||
tree.writeLastLine(2, "Memory Layout: ", signature.layout);
|
||||
|
||||
tree.writeLastLine(1, "Algorithm");
|
||||
// Compute Block section
|
||||
oss << " ├─ Compute Block: " << algorithm.block.m << "×" << algorithm.block.n << "×"
|
||||
<< algorithm.block.k << " submatrix (" << algorithm.block.block_size << " threads)\n";
|
||||
tree.writeLine(2,
|
||||
"Compute Block: ",
|
||||
algorithm.block.m,
|
||||
"×",
|
||||
algorithm.block.n,
|
||||
"×",
|
||||
algorithm.block.k,
|
||||
" submatrix (",
|
||||
algorithm.block.block_size,
|
||||
" threads)");
|
||||
|
||||
oss << " │ ├─ XDL Waves: " << algorithm.tuning.m_xdl_per_wave << "×"
|
||||
<< algorithm.tuning.n_xdl_per_wave << " mapping ("
|
||||
<< (algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave)
|
||||
<< " waves total)\n";
|
||||
oss << " │ └─ Tuning: ak1=" << algorithm.tuning.ak1 << ", bk1=" << algorithm.tuning.bk1
|
||||
<< " (optimized for MI300 MFMA)\n";
|
||||
tree.writeLine(3,
|
||||
"XDL Waves: ",
|
||||
algorithm.tuning.m_xdl_per_wave,
|
||||
"×",
|
||||
algorithm.tuning.n_xdl_per_wave,
|
||||
" mapping (",
|
||||
(algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave),
|
||||
" waves total)");
|
||||
tree.writeLastLine(3,
|
||||
"Tuning: ak1=",
|
||||
algorithm.tuning.ak1,
|
||||
", bk1=",
|
||||
algorithm.tuning.bk1,
|
||||
" (optimized for MI300 MFMA)");
|
||||
|
||||
// Memory Layout section
|
||||
oss << " ├─ Memory Layout:\n";
|
||||
oss << " │ ├─ A-Transfer: " << algorithm.a_transfer.k0 << "×"
|
||||
<< algorithm.a_transfer.m_or_n << "×" << algorithm.a_transfer.k1
|
||||
<< " thread clusters (coalesced reads)\n";
|
||||
oss << " │ ├─ B-Transfer: " << algorithm.b_transfer.k0 << "×"
|
||||
<< algorithm.b_transfer.m_or_n << "×" << algorithm.b_transfer.k1
|
||||
<< " thread clusters (broadcast-friendly)\n";
|
||||
oss << " │ └─ C-Transfer: " << algorithm.c_transfer.m_block << "×"
|
||||
<< algorithm.c_transfer.m_wave_per_xdl << "×" << algorithm.c_transfer.n_block << "×"
|
||||
<< algorithm.c_transfer.n_wave_per_xdl << " clusters (efficient writeback)\n";
|
||||
tree.writeLine(2, "Memory Layout:");
|
||||
tree.writeLine(3,
|
||||
"A-Transfer: ",
|
||||
algorithm.a_transfer.k0,
|
||||
"×",
|
||||
algorithm.a_transfer.m_or_n,
|
||||
"×",
|
||||
algorithm.a_transfer.k1,
|
||||
" thread clusters (coalesced reads)");
|
||||
tree.writeLine(3,
|
||||
"B-Transfer: ",
|
||||
algorithm.b_transfer.k0,
|
||||
"×",
|
||||
algorithm.b_transfer.m_or_n,
|
||||
"×",
|
||||
algorithm.b_transfer.k1,
|
||||
" thread clusters (broadcast-friendly)");
|
||||
tree.writeLastLine(3,
|
||||
"C-Transfer: ",
|
||||
algorithm.c_transfer.m_block,
|
||||
"×",
|
||||
algorithm.c_transfer.m_wave_per_xdl,
|
||||
"×",
|
||||
algorithm.c_transfer.n_block,
|
||||
"×",
|
||||
algorithm.c_transfer.n_wave_per_xdl,
|
||||
" clusters (efficient writeback)");
|
||||
|
||||
// Pipeline section
|
||||
oss << " └─ Pipeline: " << PipelineToString(algorithm.pipeline);
|
||||
tree.writeLastLine(2, "Pipeline: ", algorithm.pipeline);
|
||||
|
||||
return oss.str();
|
||||
return oss.str() + tree.getString();
|
||||
}
|
||||
|
||||
// Educational explanation of optimization choices
|
||||
@@ -218,7 +212,7 @@ Description Describe()
|
||||
Traits::c_block_transfer.thread_cluster_dims[2],
|
||||
Traits::c_block_transfer.thread_cluster_dims[3]},
|
||||
// pipeline
|
||||
ConvertPipelineVersion(Traits::pipeline_version)}};
|
||||
Traits::pipeline_version}};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
|
||||
138
experimental/builder/include/ck_tile/builder/tree_formatter.hpp
Normal file
138
experimental/builder/include/ck_tile/builder/tree_formatter.hpp
Normal file
@@ -0,0 +1,138 @@
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/conv_traits.hpp>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Convert CK block GEMM pipeline enums to string.
|
||||
// TODO: Remove this once we hide the pipeline version from reflection.
|
||||
constexpr std::string_view PipelineToString(ck::BlockGemmPipelineVersion pipeline)
|
||||
{
|
||||
switch(pipeline)
|
||||
{
|
||||
case ck::BlockGemmPipelineVersion::v1: return "V1";
|
||||
case ck::BlockGemmPipelineVersion::v2: return "V2";
|
||||
case ck::BlockGemmPipelineVersion::v3: return "V3";
|
||||
case ck::BlockGemmPipelineVersion::v4: return "V4";
|
||||
case ck::BlockGemmPipelineVersion::v5: return "V5";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
// enum class PipelineVersion;
|
||||
|
||||
// // Forward declare PipelineToString (actual definition must be included separately)
|
||||
// constexpr std::string_view PipelineToString(PipelineVersion);
|
||||
|
||||
// Helper class for formatting hierarchical tree structures with proper indentation
|
||||
// and tree-drawing characters (├─, └─, │, etc.)
|
||||
class TreeFormatter
|
||||
{
|
||||
public:
|
||||
TreeFormatter() = default;
|
||||
|
||||
// Write a line at the specified indentation level (branch continues after this)
|
||||
template <typename... Args>
|
||||
void writeLine(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Write the last line at the specified indentation level (branch ends)
|
||||
template <typename... Args>
|
||||
void writeLastLine(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Get the formatted string (removes trailing newline if present)
|
||||
std::string getString() const
|
||||
{
|
||||
std::string result = oss_.str();
|
||||
if(!result.empty() && result.back() == '\n')
|
||||
{
|
||||
result.pop_back();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream oss_;
|
||||
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
|
||||
|
||||
// Helper to format individual arguments with automatic type conversion
|
||||
template <typename T>
|
||||
void formatArg(const T& arg)
|
||||
{
|
||||
if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
|
||||
builder::DataType>)
|
||||
{
|
||||
oss_ << builder::DataTypeToString(arg);
|
||||
}
|
||||
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
|
||||
builder::ConvDirection>)
|
||||
{
|
||||
oss_ << builder::ConvDirectionToString(arg);
|
||||
}
|
||||
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
|
||||
builder::GroupConvLayout>)
|
||||
{
|
||||
oss_ << builder::LayoutToString(arg);
|
||||
}
|
||||
else if constexpr(std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>,
|
||||
ck::BlockGemmPipelineVersion>)
|
||||
{
|
||||
oss_ << PipelineToString(arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
oss_ << arg; // Default: just stream it
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of line writing with tree symbols
|
||||
template <typename... Args>
|
||||
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
|
||||
{
|
||||
// Ensure we have enough tracking space
|
||||
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
|
||||
{
|
||||
is_last_at_level_.resize(indent_level + 1, false);
|
||||
// Level 0 (root) should always be treated as "last" since it has no tree symbols
|
||||
if(is_last_at_level_.size() > 0)
|
||||
{
|
||||
is_last_at_level_[0] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Draw the tree structure
|
||||
// Start from level 1 (skip level 0 which is the root with no symbols)
|
||||
for(int i = 1; i < indent_level; ++i)
|
||||
{
|
||||
// For all parent levels, draw vertical line or space based on whether they ended
|
||||
oss_ << (is_last_at_level_[i] ? " " : "│ ");
|
||||
}
|
||||
|
||||
// Draw the branch symbol for the current level
|
||||
if(indent_level > 0)
|
||||
{
|
||||
oss_ << (is_last ? "└─ " : "├─ ");
|
||||
}
|
||||
|
||||
// Write the content using fold expression with formatArg
|
||||
(formatArg(std::forward<Args>(args)), ...);
|
||||
|
||||
oss_ << '\n';
|
||||
|
||||
// Update tracking for this level AFTER writing the line
|
||||
// This ensures future lines at deeper levels know if this level ended
|
||||
is_last_at_level_[indent_level] = is_last;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
Reference in New Issue
Block a user