From 4247f7f4d41acfededa650855055753be250fb0e Mon Sep 17 00:00:00 2001 From: John Shumway Date: Wed, 8 Oct 2025 01:16:15 +0000 Subject: [PATCH] Clean up description and tree formatter. --- .../ck_tile/builder/conv_description.hpp | 190 +++++++++--------- .../ck_tile/builder/tree_formatter.hpp | 26 ++- 2 files changed, 114 insertions(+), 102 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_description.hpp b/experimental/builder/include/ck_tile/builder/conv_description.hpp index 12a122d66d..b58c9eda5e 100644 --- a/experimental/builder/include/ck_tile/builder/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_description.hpp @@ -10,7 +10,6 @@ namespace ck_tile::reflect { -// Decoupled structs in the reflect namespace for runtime storage struct SignatureInfo { int spatial_dim; @@ -83,78 +82,73 @@ struct Description // Detailed hierarchical description std::string detailed() const { - std::ostringstream oss; - TreeFormatter tree; + TreeFormatter f; + f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel"); + f.writeLine(1, "Signature"); + f.writeLine(2, "Tensor Type: ", signature.data_type); + f.writeLast(2, "Memory Layout: ", signature.layout); - // Root line - no tree formatting - oss << signature.spatial_dim << "D " << builder::ConvDirectionToString(signature.direction) - << " Convolution Kernel\n"; - - tree.writeLine(1, "Signature"); - tree.writeLine(2, "Tensor Type: ", signature.data_type); - tree.writeLastLine(2, "Memory Layout: ", signature.layout); - - tree.writeLastLine(1, "Algorithm"); + f.writeLast(1, "Algorithm"); // Compute Block section - tree.writeLine(2, - "Compute Block: ", - algorithm.block.m, - "×", - algorithm.block.n, - "×", - algorithm.block.k, - " submatrix (", - algorithm.block.block_size, - " threads)"); + f.writeLine(2, + "Compute Block: ", + algorithm.block.m, + "×", + algorithm.block.n, + "×", + algorithm.block.k, + " submatrix (", + algorithm.block.block_size, + " threads)"); - tree.writeLine(3, - "XDL Waves: ", - algorithm.tuning.m_xdl_per_wave, - "×", - algorithm.tuning.n_xdl_per_wave, - " mapping (", - (algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave), - " waves total)"); - tree.writeLastLine(3, - "Tuning: ak1=", - algorithm.tuning.ak1, - ", bk1=", - algorithm.tuning.bk1, - " (optimized for MI300 MFMA)"); + f.writeLine(3, + "XDL Waves: ", + algorithm.tuning.m_xdl_per_wave, + "×", + algorithm.tuning.n_xdl_per_wave, + " mapping (", + (algorithm.tuning.m_xdl_per_wave * algorithm.tuning.n_xdl_per_wave), + " waves total)"); + f.writeLast(3, + "Tuning: ak1=", + algorithm.tuning.ak1, + ", bk1=", + algorithm.tuning.bk1, + " (optimized for MI300 MFMA)"); // Memory Layout section - tree.writeLine(2, "Memory Layout:"); - tree.writeLine(3, - "A-Transfer: ", - algorithm.a_transfer.k0, - "×", - algorithm.a_transfer.m_or_n, - "×", - algorithm.a_transfer.k1, - " thread clusters (coalesced reads)"); - tree.writeLine(3, - "B-Transfer: ", - algorithm.b_transfer.k0, - "×", - algorithm.b_transfer.m_or_n, - "×", - algorithm.b_transfer.k1, - " thread clusters (broadcast-friendly)"); - tree.writeLastLine(3, - "C-Transfer: ", - algorithm.c_transfer.m_block, - "×", - algorithm.c_transfer.m_wave_per_xdl, - "×", - algorithm.c_transfer.n_block, - "×", - algorithm.c_transfer.n_wave_per_xdl, - " clusters (efficient writeback)"); + f.writeLine(2, "Memory Layout:"); + f.writeLine(3, + "A-Transfer: ", + algorithm.a_transfer.k0, + "×", + algorithm.a_transfer.m_or_n, + "×", + algorithm.a_transfer.k1, + " thread clusters (coalesced reads)"); + f.writeLine(3, + "B-Transfer: ", + algorithm.b_transfer.k0, + "×", + algorithm.b_transfer.m_or_n, + "×", + algorithm.b_transfer.k1, + " thread clusters (broadcast-friendly)"); + f.writeLast(3, + "C-Transfer: ", + algorithm.c_transfer.m_block, + "×", + algorithm.c_transfer.m_wave_per_xdl, + "×", + algorithm.c_transfer.n_block, + "×", + algorithm.c_transfer.n_wave_per_xdl, + " clusters (efficient writeback)"); // Pipeline section - tree.writeLastLine(2, "Pipeline: ", algorithm.pipeline); + f.writeLast(2, "Pipeline: ", algorithm.pipeline); - return oss.str() + tree.getString(); + return f.getString(); } // Educational explanation of optimization choices @@ -181,38 +175,40 @@ Description Describe() using Traits = ConvTraits; return Description{ - // signature - SignatureInfo{Traits::spatial_dim, Traits::direction, Traits::layout, Traits::data_type}, - // algorithm - AlgorithmInfo{// block - BlockInfo{Traits::block.block_size, - Traits::block.per_block.m, - Traits::block.per_block.n, - Traits::block.per_block.k}, - // tuning - TuningInfo{Traits::tuning.ak1, - Traits::tuning.bk1, - Traits::tuning.m_per_xdl, - Traits::tuning.n_per_dxl, - Traits::tuning.m_xdl_per_wave, - Traits::tuning.n_xdl_per_wave}, - // a_transfer - BlockTransferInfo{Traits::a_block_transfer.thread_cluster_dims[0], - Traits::a_block_transfer.thread_cluster_dims[1], - Traits::a_block_transfer.thread_cluster_dims[2]}, - // b_transfer - BlockTransferInfo{Traits::b_block_transfer.thread_cluster_dims[0], - Traits::b_block_transfer.thread_cluster_dims[1], - Traits::b_block_transfer.thread_cluster_dims[2]}, - // c_transfer - CBlockTransferInfo{Traits::c_block_transfer.m_xdl_per_wave_per_shuffle, - Traits::c_block_transfer.n_xdl_per_wave_per_shuffle, - Traits::c_block_transfer.thread_cluster_dims[0], - Traits::c_block_transfer.thread_cluster_dims[1], - Traits::c_block_transfer.thread_cluster_dims[2], - Traits::c_block_transfer.thread_cluster_dims[3]}, - // pipeline - Traits::pipeline_version}}; + .signature = SignatureInfo{.spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .layout = Traits::layout, + .data_type = Traits::data_type}, + .algorithm = AlgorithmInfo{ + .block = BlockInfo{.block_size = Traits::block.block_size, + .m = Traits::block.per_block.m, + .n = Traits::block.per_block.n, + .k = Traits::block.per_block.k}, + .tuning = TuningInfo{.ak1 = Traits::tuning.ak1, + .bk1 = Traits::tuning.bk1, + .m_per_xdl = Traits::tuning.m_per_xdl, + .n_per_xdl = Traits::tuning.n_per_dxl, + .m_xdl_per_wave = Traits::tuning.m_xdl_per_wave, + .n_xdl_per_wave = Traits::tuning.n_xdl_per_wave}, + .a_transfer = + BlockTransferInfo{.k0 = Traits::a_block_transfer.thread_cluster_dims[0], + .m_or_n = Traits::a_block_transfer.thread_cluster_dims[1], + .k1 = Traits::a_block_transfer.thread_cluster_dims[2]}, + .b_transfer = + BlockTransferInfo{.k0 = Traits::b_block_transfer.thread_cluster_dims[0], + .m_or_n = Traits::b_block_transfer.thread_cluster_dims[1], + .k1 = Traits::b_block_transfer.thread_cluster_dims[2]}, + .c_transfer = + CBlockTransferInfo{ + .m_xdl_per_wave_per_shuffle = + Traits::c_block_transfer.m_xdl_per_wave_per_shuffle, + .n_xdl_per_wave_per_shuffle = + Traits::c_block_transfer.n_xdl_per_wave_per_shuffle, + .m_block = Traits::c_block_transfer.thread_cluster_dims[0], + .m_wave_per_xdl = Traits::c_block_transfer.thread_cluster_dims[1], + .n_block = Traits::c_block_transfer.thread_cluster_dims[2], + .n_wave_per_xdl = Traits::c_block_transfer.thread_cluster_dims[3]}, + .pipeline = Traits::pipeline_version}}; } } // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/tree_formatter.hpp index ae52afce5e..fd7460d1c7 100644 --- a/experimental/builder/include/ck_tile/builder/tree_formatter.hpp +++ b/experimental/builder/include/ck_tile/builder/tree_formatter.hpp @@ -24,13 +24,29 @@ constexpr std::string_view PipelineToString(ck::BlockGemmPipelineVersion pipelin default: return "Unknown"; } } -// enum class PipelineVersion; - -// // Forward declare PipelineToString (actual definition must be included separately) -// constexpr std::string_view PipelineToString(PipelineVersion); // Helper class for formatting hierarchical tree structures with proper indentation // and tree-drawing characters (├─, └─, │, etc.) +// +// Example Usage: +// +// TreeFormatter f; +// f.writeLine(0, "Root"); +// f.writeLine(1, "Branch 1"); +// f.writeLine(2, "Item 1a"); +// f.writeLast(2, "Item 1b"); +// f.writeLast(1, "Branch 2"); +// f.writeLast(2, "Item 2a"); +// std::cout << f.getString() << "\n"; +// +// Generated Output: +// +// Root +// ├─ Branch 1 +// │ ├─ Item 1a +// │ └─ Item 1b +// └─ Branch 2 +// └─ Item 2a class TreeFormatter { public: @@ -45,7 +61,7 @@ class TreeFormatter // Write the last line at the specified indentation level (branch ends) template - void writeLastLine(int indent_level, Args&&... args) + void writeLast(int indent_level, Args&&... args) { writeLineImpl(indent_level, true, std::forward(args)...); }