[CK_BUILDER] Convolution description (#3163)

* Add DirectLoad tparam & clean up headers.

* Add convolution traits.

* Update inline documentation.

* Add more convolution specialization and gemm padding types.

* Add additional helper functions & more tests to conv traits.

* Fix tests cmake file.

* Add case insensitive string comparison

* Fix function name overlapping with variable name.

* Unify pipeline version and scheduler enums.

* Fix includes.

* Update test conv traits with unified enums.

* Update concepts etc with update unified enum

* Fix ckb conv fwd test - unified enum usage.

* Dump changes.

* Add ostream overloads for all enum classes.

* Update detailed() function in ConvDescription

* Fix handling union based conv direction.

* Add test & update conv description.

* Refine tree view.

* Update copyrights

* Fix merge artifacts

* Update detailed tree conv description

* Fix clang-format
This commit is contained in:
Adam Osewski
2025-11-06 15:46:26 +01:00
committed by GitHub
parent 2234ff830b
commit 18e083003f
9 changed files with 842 additions and 64 deletions

View File

@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
switch(dir)
{
case ConvDirection::FORWARD: return "Forward";
case ConvDirection::BACKWARD_DATA: return "Backward Data";
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
constexpr std::string_view DataTypeToString(DataType dt)
{
switch(dt)
{
case DataType::FP16: return "FP16";
case DataType::FP32: return "FP32";
case DataType::BF16: return "BF16";
case DataType::FP8: return "FP8";
case DataType::I8: return "I8";
case DataType::U8: return "U8";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
switch(layout)
{
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
switch(layout)
{
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
switch(layout)
{
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
default: return "Unknown";
}
}
} // namespace ck_tile::builder

View File

@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
// Predicate for DeviceGroupedConvBwdWeight operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <variant>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
/// @file conv_description.hpp
/// @brief Provides human-readable descriptions of ConvBuilder configurations
namespace ck_tile::reflect::conv {
struct ConvSignatureInfo
{
int spatial_dim;
builder::ConvDirection direction;
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
layout;
builder::DataType data_type;
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
// Algorithm information - groups all algorithm-related configuration
struct GemmAlgorithmInfo
{
int thread_block_size;
DataTileInfo tile_dims;
WarpGemmParams warp_gemm;
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
std::variant<builder::ConvFwdSpecialization,
builder::ConvBwdDataSpecialization,
builder::ConvBwdWeightSpecialization>
conv_specialization;
builder::GemmPadding padding;
};
// Provides human-readable descriptions of ConvBuilder configurations.
struct ConvDescription
{
ConvSignatureInfo signature;
GemmAlgorithmInfo algorithm;
// Brief one-line summary
std::string brief() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
return oss.str();
}
// Detailed hierarchical description
std::string detailed() const
{
TreeFormatter f;
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature.data_type);
f.writeLine(2, "Memory Layout: ", signature.layout);
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
f.writeLine(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm.tile_dims.m,
"×",
algorithm.tile_dims.n,
"×",
algorithm.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm.padding);
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm.warp_gemm.m_iter,
"×",
algorithm.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
algorithm.c_tile_transfer.scalar_per_vector);
f.writeLast(2);
f.writeLast(1);
return f.getString();
}
// Educational explanation of optimization choices
std::string explain() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
// Performance characteristics and use case guidance
std::string suggest() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
};
// Helper concept to detect if a type has InstanceTraits specialization
template <typename T>
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
// Helper concept to detect ConvBuilder types
template <typename T>
concept IsConvBuilder = requires {
typename T::Factory;
typename T::Instance;
};
// Primary factory function: Create ConvDescription from Instance type directly
template <typename Instance>
requires HasInstanceTraits<Instance>
ConvDescription Describe()
{
using Traits = ConvTraits<Instance>;
return ConvDescription{
.signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.layout = Traits::layout,
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op},
.algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding}};
}
// Backward compatibility: Create ConvDescription from Builder type
template <typename Builder>
requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
ConvDescription Describe()
{
// Delegate to Instance-based version
using Instance = typename Builder::Instance;
return Describe<Instance>();
}
} // namespace ck_tile::reflect::conv

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -13,7 +13,10 @@
#include <string_view>
#include <sstream>
#include <type_traits>
#include <climits>
#include <limits.h>
#include <cmath>
#include <ostream>
#include <iostream>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>

View File

@@ -0,0 +1,106 @@
#pragma once
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
namespace ck_tile::reflect {
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
//
// Example Usage:
//
// TreeFormatter f;
// f.writeLine(0, "Root");
// f.writeLine(1, "Branch 1");
// f.writeLine(2, "Item 1a");
// f.writeLast(2, "Item 1b");
// f.writeLast(1, "Branch 2");
// f.writeLast(2, "Item 2a");
// std::cout << f.getString() << "\n";
//
// Generated Output:
//
// Root
// ├─ Branch 1
// │ ├─ Item 1a
// │ └─ Item 1b
// └─ Branch 2
// └─ Item 2a
class TreeFormatter
{
public:
TreeFormatter() = default;
// Write a line at the specified indentation level (branch continues after this)
template <typename... Args>
void writeLine(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
}
// Write the last line at the specified indentation level (branch ends)
template <typename... Args>
void writeLast(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
}
// Get the formatted string (removes trailing newline if present)
std::string getString() const
{
std::string result = oss_.str();
if(!result.empty() && result.back() == '\n')
{
result.pop_back();
}
return result;
}
private:
std::ostringstream oss_;
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
// Implementation of line writing with tree symbols
template <typename... Args>
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
{
// Ensure we have enough tracking space
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
{
is_last_at_level_.resize(indent_level + 1, false);
// Level 0 (root) should always be treated as "last" since it has no tree symbols
if(is_last_at_level_.size() > 0)
{
is_last_at_level_[0] = true;
}
}
// Draw the tree structure
// Start from level 1 (skip level 0 which is the root with no symbols)
for(int i = 1; i < indent_level; ++i)
{
// For all parent levels, draw vertical line or space based on whether they ended
oss_ << (is_last_at_level_[i] ? " " : "");
}
// Draw the branch symbol for the current level
if(indent_level > 0)
{
oss_ << (is_last ? "└─ " : "├─ ");
}
// Write the content using fold expression with direct stream insertion
((oss_ << std::forward<Args>(args)), ...);
oss_ << '\n';
// Update tracking for this level AFTER writing the line
// This ensures future lines at deeper levels know if this level ended
is_last_at_level_[indent_level] = is_last;
}
};
} // namespace ck_tile::reflect

View File

@@ -3,6 +3,10 @@
#pragma once
#include <ostream>
#include <string_view>
#include <variant>
namespace ck_tile::builder {
enum class DataType
@@ -215,4 +219,275 @@ enum class PipelineScheduler
INTERWAVE
};
// ostream operator overloads for enum classes
inline std::ostream& operator<<(std::ostream& os, DataType dt)
{
using enum DataType;
switch(dt)
{
case FP16: return os << "FP16";
case FP32: return os << "FP32";
case BF16: return os << "BF16";
case FP8: return os << "FP8";
case I8: return os << "I8";
case U8: return os << "U8";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
using enum ConvDirection;
switch(dir)
{
case FORWARD: return os << "Forward";
case BACKWARD_DATA: return os << "Backward Data";
case BACKWARD_WEIGHT: return os << "Backward Weight";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
using enum GroupConvLayout1D;
switch(layout)
{
case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
using enum GroupConvLayout2D;
switch(layout)
{
case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
using enum GroupConvLayout3D;
switch(layout)
{
case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
{
using enum FwdGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
{
using enum BwdDataGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
{
using enum BwdWeightGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
using enum ElementwiseOperation;
switch(op)
{
case BIAS: return os << "BIAS";
case BIAS_CLAMP: return os << "BIAS_CLAMP";
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
case BILINEAR: return os << "BILINEAR";
case CLAMP: return os << "CLAMP";
case SCALE: return os << "SCALE";
case PASS_THROUGH: return os << "PASS_THROUGH";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
using enum PipelineVersion;
switch(ver)
{
case V1: return os << "V1";
case V2: return os << "V2";
case V3: return os << "V3";
case V4: return os << "V4";
case V5: return os << "V5";
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
using enum GemmSpecialization;
switch(spec)
{
case Default: return os << "Default";
case MPadding: return os << "MPadding";
case NPadding: return os << "NPadding";
case KPadding: return os << "KPadding";
case MNPadding: return os << "MNPadding";
case MKPadding: return os << "MKPadding";
case NKPadding: return os << "NKPadding";
case MNKPadding: return os << "MNKPadding";
case OPadding: return os << "OPadding";
case MOPadding: return os << "MOPadding";
case NOPadding: return os << "NOPadding";
case KOPadding: return os << "KOPadding";
case MNOPadding: return os << "MNOPadding";
case MKOPadding: return os << "MKOPadding";
case NKOPadding: return os << "NKOPadding";
case MNKOPadding: return os << "MNKOPadding";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
using enum ConvFwdSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return os << "FILTER_3x3";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
using enum ConvBwdDataSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
using enum ConvBwdWeightSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case ODD_C: return os << "ODD_C";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
using enum GemmPadding;
switch(padding)
{
case DEFAULT: return os << "DEFAULT";
case M_PADDING: return os << "M_PADDING";
case N_PADDING: return os << "N_PADDING";
case K_PADDING: return os << "K_PADDING";
case MN_PADDING: return os << "MN_PADDING";
case MK_PADDING: return os << "MK_PADDING";
case NK_PADDING: return os << "NK_PADDING";
case MNK_PADDING: return os << "MNK_PADDING";
case O_PADDING: return os << "O_PADDING";
case MO_PADDING: return os << "MO_PADDING";
case NO_PADDING: return os << "NO_PADDING";
case KO_PADDING: return os << "KO_PADDING";
case MNO_PADDING: return os << "MNO_PADDING";
case MKO_PADDING: return os << "MKO_PADDING";
case NKO_PADDING: return os << "NKO_PADDING";
case MNKO_PADDING: return os << "MNKO_PADDING";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
using enum PipelineScheduler;
switch(sched)
{
case DEFAULT: return os << "DEFAULT";
case INTRAWAVE: return os << "INTRAWAVE";
case INTERWAVE: return os << "INTERWAVE";
default: return os << "Unknown";
}
}
// ostream operator overload for std::variant of layout types
inline std::ostream&
operator<<(std::ostream& os,
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
{
std::visit([&os](const auto& l) { os << l; }, layout);
return os;
}
// ostream operator overload for std::variant of convolution specializations
inline std::ostream& operator<<(std::ostream& os,
const std::variant<ConvFwdSpecialization,
ConvBwdDataSpecialization,
ConvBwdWeightSpecialization>& spec)
{
std::visit([&os](const auto& s) { os << s; }, spec);
return os;
}
} // namespace ck_tile::builder

View File

@@ -67,6 +67,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test
add_ck_builder_test(test_conv_traits
conv/test_conv_traits.cpp)
add_ck_builder_test(test_conv_description
test_conv_description.cpp)
# Function to add all test_ckb targets to a list
function(collect_test_ckb_targets result_var)
# Get all targets in current directory

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckr = ck_tile::reflect::conv;
namespace ckt = ck_tile::test;
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, and data type.
struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
ckb::GroupConvDeviceOp device_operation =
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
struct DefaultAlgorithm
{
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
ckb::test::BlockTransferABC block_transfer{
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {.order = {0, 1, 2}},
.block_transfer_access_order_b = {.order = {0, 1, 2}},
.src_access_order_a = {.order = {0, 1, 2}},
.src_access_order_b = {.order = {0, 1, 2}}};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
EXPECT_THAT(ckr::Describe<Builder>().brief(), ckt::StringEqWithDiff("2D Forward convolution"));
}
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
{
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"├─ Algorithm\n"
"│ ├─ Thread block size: 256\n"
"│ ├─ Data tile size: 256×256×32\n"
"│ ├─ Gemm padding: DEFAULT\n"
"│ ├─ Convolution specialization: DEFAULT\n"
"│ ├─ Pipeline version: V4\n"
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
"│ ├─ Warp Gemm parameters: \n"
"│ │ ├─ subtile size: 16×16\n"
"│ │ └─ Number of warp gemm iterations: 4×4\n"
"│ ├─ Memory access:\n"
"│ │ ├─ A Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ ├─ B Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ └─ C Tile transfer: \n"
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
"│ └─ \n"
"└─ "));
}
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
// does not have a specialization for backward data convolutions. The test fails with:
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
//
// To enable this test, a ConvFactory specialization for backward data operations must be
// implemented first.
//
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
// {
// struct BackwardDataSignature
// {
// int spatial_dim = 2;
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
// ckb::DataType data_type = ckb::DataType::FP16;
// ckb::ElementwiseOperation elementwise_operation =
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
// };
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
//
// static constexpr const BackwardDataSignature SIGNATURE;
// static constexpr const DefaultAlgorithm ALGORITHM;
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
//
// // Verify Brief works
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
// ckt::StringEqWithDiff("2D Backward Data convolution"));
//
// // Verify detailed works - to be updated once ConvFactory is implemented
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
// ckt::StringEqWithDiff("PLACEHOLDER"));
// }
} // namespace