mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Convolution description (#3163)
* Add DirectLoad tparam & clean up headers. * Add convolution traits. * Update inline documentation. * Add more convolution specialization and gemm padding types. * Add additional helper functions & more tests to conv traits. * Fix tests cmake file. * Add case insensitive string comparison * Fix function name overlapping with variable name. * Unify pipeline version and scheduler enums. * Fix includes. * Update test conv traits with unified enums. * Update concepts etc with update unified enum * Fix ckb conv fwd test - unified enum usage. * Dump changes. * Add ostream overloads for all enum classes. * Update detailed() function in ConvDescription * Fix handling union based conv direction. * Add test & update conv description. * Refine tree view. * Update copyrights * Fix merge artifacts * Update detailed tree conv description * Fix clang-format
This commit is contained in:
@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
// Returns the display name of a convolution direction.
// Any value without a dedicated case is reported as "Unknown".
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
    std::string_view text = "Unknown";
    switch(dir)
    {
    case ConvDirection::FORWARD: text = "Forward"; break;
    case ConvDirection::BACKWARD_DATA: text = "Backward Data"; break;
    case ConvDirection::BACKWARD_WEIGHT: text = "Backward Weight"; break;
    default: break;
    }
    return text;
}
|
||||
|
||||
// Returns the display name of a data type enumerator.
// Any value without a dedicated case is reported as "Unknown".
constexpr std::string_view DataTypeToString(DataType dt)
{
    std::string_view text = "Unknown";
    switch(dt)
    {
    case DataType::FP16: text = "FP16"; break;
    case DataType::FP32: text = "FP32"; break;
    case DataType::BF16: text = "BF16"; break;
    case DataType::FP8: text = "FP8"; break;
    case DataType::I8: text = "I8"; break;
    case DataType::U8: text = "U8"; break;
    default: break;
    }
    return text;
}
|
||||
|
||||
// Returns the enumerator name of a 1D grouped-convolution layout as text.
// Any value without a dedicated case is reported as "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
    std::string_view text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout1D::GNWC_GKXC_GNWK: text = "GNWC_GKXC_GNWK"; break;
    case GroupConvLayout1D::NWGC_GKXC_NWGK: text = "NWGC_GKXC_NWGK"; break;
    case GroupConvLayout1D::NGCW_GKXC_NGKW: text = "NGCW_GKXC_NGKW"; break;
    case GroupConvLayout1D::NGCW_GKCX_NGKW: text = "NGCW_GKCX_NGKW"; break;
    default: break;
    }
    return text;
}
|
||||
|
||||
// Returns the enumerator name of a 2D grouped-convolution layout as text.
// Any value without a dedicated case is reported as "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
    std::string_view text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: text = "GNHWC_GKYXC_GNHWK"; break;
    case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: text = "NHWGC_GKYXC_NHWGK"; break;
    case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: text = "NGCHW_GKYXC_NGKHW"; break;
    case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: text = "NGCHW_GKCYX_NGKHW"; break;
    default: break;
    }
    return text;
}
|
||||
|
||||
// Returns the enumerator name of a 3D grouped-convolution layout as text.
// Any value without a dedicated case is reported as "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
    std::string_view text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: text = "GNDHWC_GKZYXC_GNDHWK"; break;
    case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: text = "NDHWGC_GKZYXC_NDHWGK"; break;
    case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: text = "NGCDHW_GKZYXC_NGKDHW"; break;
    case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: text = "NGCDHW_GKCZYX_NGKDHW"; break;
    default: break;
    }
    return text;
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
|
||||
@@ -0,0 +1,268 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/tree_formatter.hpp>
|
||||
|
||||
/// @file conv_description.hpp
|
||||
/// @brief Provides human-readable descriptions of ConvBuilder configurations
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
struct ConvSignatureInfo
|
||||
{
|
||||
int spatial_dim;
|
||||
builder::ConvDirection direction;
|
||||
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
|
||||
layout;
|
||||
builder::DataType data_type;
|
||||
builder::ElementwiseOperation input_element_op;
|
||||
builder::ElementwiseOperation weight_element_op;
|
||||
builder::ElementwiseOperation output_element_op;
|
||||
};
|
||||
|
||||
// Algorithm information - groups all algorithm-related configuration
|
||||
struct GemmAlgorithmInfo
|
||||
{
|
||||
int thread_block_size;
|
||||
DataTileInfo tile_dims;
|
||||
WarpGemmParams warp_gemm;
|
||||
InputTileTransferInfo a_tile_transfer;
|
||||
InputTileTransferInfo b_tile_transfer;
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
std::variant<builder::ConvFwdSpecialization,
|
||||
builder::ConvBwdDataSpecialization,
|
||||
builder::ConvBwdWeightSpecialization>
|
||||
conv_specialization;
|
||||
builder::GemmPadding padding;
|
||||
};
|
||||
|
||||
// Provides human-readable descriptions of ConvBuilder configurations.
|
||||
struct ConvDescription
|
||||
{
|
||||
ConvSignatureInfo signature;
|
||||
GemmAlgorithmInfo algorithm;
|
||||
|
||||
// Brief one-line summary
|
||||
std::string brief() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Detailed hierarchical description
|
||||
std::string detailed() const
|
||||
{
|
||||
TreeFormatter f;
|
||||
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
|
||||
f.writeLine(1, "Signature");
|
||||
f.writeLine(2, "Tensor Type: ", signature.data_type);
|
||||
f.writeLine(2, "Memory Layout: ", signature.layout);
|
||||
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
|
||||
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
|
||||
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
|
||||
|
||||
f.writeLine(1, "Algorithm");
|
||||
// Compute Block section
|
||||
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
|
||||
f.writeLine(2,
|
||||
"Data tile size: ",
|
||||
algorithm.tile_dims.m,
|
||||
"×",
|
||||
algorithm.tile_dims.n,
|
||||
"×",
|
||||
algorithm.tile_dims.k);
|
||||
f.writeLine(2, "Gemm padding: ", algorithm.padding);
|
||||
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
|
||||
// Pipeline section
|
||||
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
|
||||
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
|
||||
f.writeLine(2, "Warp Gemm parameters: ");
|
||||
f.writeLine(
|
||||
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
|
||||
f.writeLast(3,
|
||||
"Number of warp gemm iterations: ",
|
||||
algorithm.warp_gemm.m_iter,
|
||||
"×",
|
||||
algorithm.warp_gemm.n_iter);
|
||||
|
||||
// Memory Access section
|
||||
f.writeLine(2, "Memory access:");
|
||||
|
||||
f.writeLine(3, "A Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLine(3, "B Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLast(3, "C Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Data shuffle (number of gemm instructions per iteration): ",
|
||||
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
|
||||
"×",
|
||||
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution used to store data: ",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[0],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[1],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[2],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[3]);
|
||||
f.writeLast(4,
|
||||
"Vector access (GMEM write) instruction size: ",
|
||||
algorithm.c_tile_transfer.scalar_per_vector);
|
||||
f.writeLast(2);
|
||||
f.writeLast(1);
|
||||
return f.getString();
|
||||
}
|
||||
|
||||
// Educational explanation of optimization choices
|
||||
std::string explain() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Performance characteristics and use case guidance
|
||||
std::string suggest() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
// Helper concept to detect if a type has InstanceTraits specialization
|
||||
template <typename T>
|
||||
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
|
||||
|
||||
// Helper concept to detect ConvBuilder types
|
||||
template <typename T>
|
||||
concept IsConvBuilder = requires {
|
||||
typename T::Factory;
|
||||
typename T::Instance;
|
||||
};
|
||||
|
||||
// Primary factory function: Create ConvDescription from Instance type directly
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance>
|
||||
ConvDescription Describe()
|
||||
{
|
||||
using Traits = ConvTraits<Instance>;
|
||||
|
||||
return ConvDescription{
|
||||
.signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.layout = Traits::layout,
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op},
|
||||
.algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding}};
|
||||
}
|
||||
|
||||
// Backward compatibility: Create ConvDescription from Builder type
|
||||
template <typename Builder>
|
||||
requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
|
||||
ConvDescription Describe()
|
||||
{
|
||||
// Delegate to Instance-based version
|
||||
using Instance = typename Builder::Instance;
|
||||
return Describe<Instance>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -13,7 +13,10 @@
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <climits>
|
||||
#include <limits.h>
|
||||
#include <cmath>
|
||||
#include <ostream>
|
||||
#include <iostream>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
//
// Lines are appended one at a time together with their depth. Level 0 is the root and
// gets no tree symbols; deeper levels are prefixed with one three-column cell per
// ancestor level followed by a branch marker. A negative depth is treated as level 0.
//
// Example Usage:
//
//   TreeFormatter f;
//   f.writeLine(0, "Root");
//   f.writeLine(1, "Branch 1");
//   f.writeLine(2, "Item 1a");
//   f.writeLast(2, "Item 1b");
//   f.writeLast(1, "Branch 2");
//   f.writeLast(2, "Item 2a");
//   std::cout << f.getString() << "\n";
//
// Generated Output:
//
//   Root
//   ├─ Branch 1
//   │  ├─ Item 1a
//   │  └─ Item 1b
//   └─ Branch 2
//      └─ Item 2a
class TreeFormatter
{
    public:
    TreeFormatter() = default;

    // Write a line at the specified indentation level (branch continues after this).
    // The remaining arguments are streamed, in order, as the line's content.
    template <typename... Args>
    void writeLine(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, false, std::forward<Args>(args)...);
    }

    // Write the last line at the specified indentation level (branch ends).
    // Calling it with no content arguments emits just the closing branch marker.
    template <typename... Args>
    void writeLast(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, true, std::forward<Args>(args)...);
    }

    // Get the formatted string (removes trailing newline if present).
    std::string getString() const
    {
        std::string result = oss_.str();
        if(!result.empty() && result.back() == '\n')
        {
            result.pop_back();
        }
        return result;
    }

    private:
    using Level = std::vector<bool>::size_type;

    std::ostringstream oss_;
    std::vector<bool> is_last_at_level_; // Tracks which levels have ended

    // Implementation of line writing with tree symbols.
    template <typename... Args>
    void writeLineImpl(int indent_level, bool is_last, Args&&... args)
    {
        // A negative level would wrap around in the unsigned conversion and lead to an
        // out-of-range access or an enormous resize; treat it as the root level instead.
        const Level level = indent_level > 0 ? static_cast<Level>(indent_level) : Level{0};

        // Ensure we have enough tracking space
        if(level >= is_last_at_level_.size())
        {
            is_last_at_level_.resize(level + 1, false);
            // Level 0 (root) should always be treated as "last" since it has no tree symbols
            is_last_at_level_[0] = true;
        }

        // One cell per ancestor level (level 0 is skipped: the root has no symbols).
        // Each cell is three columns wide, the same width as a branch marker, so nested
        // markers line up: blank if that ancestor's branch ended, a vertical bar otherwise.
        for(Level i = 1; i < level; ++i)
        {
            oss_ << (is_last_at_level_[i] ? "   " : "│  ");
        }

        // Draw the branch symbol for the current level
        if(level > 0)
        {
            oss_ << (is_last ? "└─ " : "├─ ");
        }

        // Write the content using fold expression with direct stream insertion
        ((oss_ << std::forward<Args>(args)), ...);
        oss_ << '\n';

        // Update tracking for this level AFTER writing the line, so that later lines at
        // deeper levels know whether this level's branch has ended.
        is_last_at_level_[level] = is_last;
    }
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -3,6 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
@@ -215,4 +219,275 @@ enum class PipelineScheduler
|
||||
INTERWAVE
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
// Streams the name of a DataType enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, DataType dt)
{
    const char* label = "Unknown";
    switch(dt)
    {
    case DataType::FP16: label = "FP16"; break;
    case DataType::FP32: label = "FP32"; break;
    case DataType::BF16: label = "BF16"; break;
    case DataType::FP8: label = "FP8"; break;
    case DataType::I8: label = "I8"; break;
    case DataType::U8: label = "U8"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams a readable name for a ConvDirection; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
    const char* label = "Unknown";
    switch(dir)
    {
    case ConvDirection::FORWARD: label = "Forward"; break;
    case ConvDirection::BACKWARD_DATA: label = "Backward Data"; break;
    case ConvDirection::BACKWARD_WEIGHT: label = "Backward Weight"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a GroupConvLayout1D enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout1D::GNWC_GKXC_GNWK: label = "GNWC_GKXC_GNWK"; break;
    case GroupConvLayout1D::NWGC_GKXC_NWGK: label = "NWGC_GKXC_NWGK"; break;
    case GroupConvLayout1D::NGCW_GKXC_NGKW: label = "NGCW_GKXC_NGKW"; break;
    case GroupConvLayout1D::NGCW_GKCX_NGKW: label = "NGCW_GKCX_NGKW"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a GroupConvLayout2D enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: label = "GNHWC_GKYXC_GNHWK"; break;
    case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: label = "NHWGC_GKYXC_NHWGK"; break;
    case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: label = "NGCHW_GKYXC_NGKHW"; break;
    case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: label = "NGCHW_GKCYX_NGKHW"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a GroupConvLayout3D enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
    const char* label = "Unknown";
    switch(layout)
    {
    case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: label = "GNDHWC_GKZYXC_GNDHWK"; break;
    case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: label = "NDHWGC_GKZYXC_NDHWGK"; break;
    case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: label = "NGCDHW_GKZYXC_NGKDHW"; break;
    case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: label = "NGCDHW_GKCZYX_NGKDHW"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a forward device operation; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
{
    const char* label = "Unknown";
    switch(op)
    {
    case FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
        label = "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
        break;
    case FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
        label = "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
        break;
    case FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
        label = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
        break;
    case FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
        label = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
        break;
    case FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
        label = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
        break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a backward-data device operation; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
{
    const char* label = "Unknown";
    switch(op)
    {
    case BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD:
        label = "DeviceGroupedConvBwdDataMultipleD";
        break;
    case BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
        label = "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
        break;
    case BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
        label = "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
        break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a backward-weight device operation; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
{
    const char* label = "Unknown";
    switch(op)
    {
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight:
        label = "DeviceGroupedConvBwdWeight";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl:
        label = "DeviceGroupedConvBwdWeight_Dl";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle:
        label = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
        label = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle:
        label = "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
        label = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD:
        label = "DeviceGroupedConvBwdWeightMultipleD";
        break;
    case BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
        label = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
        break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of an ElementwiseOperation enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
    const char* label = "Unknown";
    switch(op)
    {
    case ElementwiseOperation::BIAS: label = "BIAS"; break;
    case ElementwiseOperation::BIAS_CLAMP: label = "BIAS_CLAMP"; break;
    case ElementwiseOperation::BIAS_BNORM_CLAMP: label = "BIAS_BNORM_CLAMP"; break;
    case ElementwiseOperation::BILINEAR: label = "BILINEAR"; break;
    case ElementwiseOperation::CLAMP: label = "CLAMP"; break;
    case ElementwiseOperation::SCALE: label = "SCALE"; break;
    case ElementwiseOperation::PASS_THROUGH: label = "PASS_THROUGH"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a PipelineVersion enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
    const char* label = "Unknown";
    switch(ver)
    {
    case PipelineVersion::V1: label = "V1"; break;
    case PipelineVersion::V2: label = "V2"; break;
    case PipelineVersion::V3: label = "V3"; break;
    case PipelineVersion::V4: label = "V4"; break;
    case PipelineVersion::V5: label = "V5"; break;
    case PipelineVersion::WEIGHT_ONLY: label = "WEIGHT_ONLY"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a GemmSpecialization enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case GemmSpecialization::Default: label = "Default"; break;
    case GemmSpecialization::MPadding: label = "MPadding"; break;
    case GemmSpecialization::NPadding: label = "NPadding"; break;
    case GemmSpecialization::KPadding: label = "KPadding"; break;
    case GemmSpecialization::MNPadding: label = "MNPadding"; break;
    case GemmSpecialization::MKPadding: label = "MKPadding"; break;
    case GemmSpecialization::NKPadding: label = "NKPadding"; break;
    case GemmSpecialization::MNKPadding: label = "MNKPadding"; break;
    case GemmSpecialization::OPadding: label = "OPadding"; break;
    case GemmSpecialization::MOPadding: label = "MOPadding"; break;
    case GemmSpecialization::NOPadding: label = "NOPadding"; break;
    case GemmSpecialization::KOPadding: label = "KOPadding"; break;
    case GemmSpecialization::MNOPadding: label = "MNOPadding"; break;
    case GemmSpecialization::MKOPadding: label = "MKOPadding"; break;
    case GemmSpecialization::NKOPadding: label = "NKOPadding"; break;
    case GemmSpecialization::MNKOPadding: label = "MNKOPadding"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a ConvFwdSpecialization enumerator; unmatched values print "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvFwdSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvFwdSpecialization::FILTER_1X1_PAD0: label = "FILTER_1X1_PAD0"; break;
    case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    case ConvFwdSpecialization::FILTER_3x3: label = "FILTER_3x3"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a ConvBwdDataSpecialization enumerator; unmatched values print
// "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvBwdDataSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
// Streams the name of a ConvBwdWeightSpecialization enumerator; unmatched values print
// "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
    const char* label = "Unknown";
    switch(spec)
    {
    case ConvBwdWeightSpecialization::DEFAULT: label = "DEFAULT"; break;
    case ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0:
        label = "FILTER_1X1_STRIDE1_PAD0";
        break;
    case ConvBwdWeightSpecialization::FILTER_1X1_PAD0: label = "FILTER_1X1_PAD0"; break;
    case ConvBwdWeightSpecialization::ODD_C: label = "ODD_C"; break;
    default: break;
    }
    return os << label;
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case M_PADDING: return os << "M_PADDING";
|
||||
case N_PADDING: return os << "N_PADDING";
|
||||
case K_PADDING: return os << "K_PADDING";
|
||||
case MN_PADDING: return os << "MN_PADDING";
|
||||
case MK_PADDING: return os << "MK_PADDING";
|
||||
case NK_PADDING: return os << "NK_PADDING";
|
||||
case MNK_PADDING: return os << "MNK_PADDING";
|
||||
case O_PADDING: return os << "O_PADDING";
|
||||
case MO_PADDING: return os << "MO_PADDING";
|
||||
case NO_PADDING: return os << "NO_PADDING";
|
||||
case KO_PADDING: return os << "KO_PADDING";
|
||||
case MNO_PADDING: return os << "MNO_PADDING";
|
||||
case MKO_PADDING: return os << "MKO_PADDING";
|
||||
case NKO_PADDING: return os << "NKO_PADDING";
|
||||
case MNKO_PADDING: return os << "MNKO_PADDING";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case INTRAWAVE: return os << "INTRAWAVE";
|
||||
case INTERWAVE: return os << "INTERWAVE";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of layout types
|
||||
inline std::ostream&
|
||||
operator<<(std::ostream& os,
|
||||
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
|
||||
{
|
||||
std::visit([&os](const auto& l) { os << l; }, layout);
|
||||
return os;
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
ConvBwdDataSpecialization,
|
||||
ConvBwdWeightSpecialization>& spec)
|
||||
{
|
||||
std::visit([&os](const auto& s) { os << s; }, spec);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -67,6 +67,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test
|
||||
# Builder test targets: each call passes a target name followed by its source file.
# NOTE(review): add_ck_builder_test is defined outside this hunk; the argument
# meaning above is inferred from the call sites - confirm against its definition.
add_ck_builder_test(test_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
# Test target for the convolution description added alongside test_conv_description.cpp.
add_ck_builder_test(test_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
function(collect_test_ckb_targets result_var)
|
||||
# Get all targets in current directory
|
||||
|
||||
169
experimental/builder/test/test_conv_description.cpp
Normal file
169
experimental/builder/test/test_conv_description.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckr = ck_tile::reflect::conv;
|
||||
namespace ckt = ck_tile::test;
|
||||
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type,
// as well as the elementwise operation and the device operation to use.
// Every member has a default, so a default-initialized instance describes a
// complete signature; the tests below pass such an instance to ckb::ConvBuilder.
|
||||
struct ConvSignature
|
||||
{
|
||||
// Number of spatial dimensions (the tests below expect the text "2D").
int spatial_dim = 2;
|
||||
// Convolution direction (the tests below expect the text "Forward").
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
// Tensor memory layout, taken from the 2D layout enum to match spatial_dim.
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
// Tensor element type.
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
// NOTE(review): a single value is given here, while the expected detailed
// description lists PASS_THROUGH for input, weights and output - presumably
// this one value is applied to all three; confirm.
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
// Device operation selected for this signature (a forward operation).
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
// Compile-time check that ConvSignature satisfies ckb::ConvSignatureDescriptor.
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
// Algorithm configuration paired with ConvSignature in the tests below.
// Every member has a default, so a default-initialized instance is a complete
// configuration. Several of the numbers set here reappear verbatim in the
// expected detailed description (e.g. thread block size 256, data tile
// 256x256x32, pipeline version V4, scheduler INTRAWAVE).
struct DefaultAlgorithm
|
||||
{
|
||||
// Thread block size and the m/n/k data tile size.
ckb::test::ThreadBlock thread_block{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
// Gridwise XDL GEMM parameters.
// NOTE(review): the expected description prints "subtile size: 16x16" and
// "Number of warp gemm iterations: 4x4"; presumably these come from
// m/n_per_xdl and m/n_xdl_per_wave respectively - confirm.
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
// Block transfer settings for the A, B and C tiles. A and B are configured
// identically here.
ckb::test::BlockTransferABC block_transfer{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
// NOTE(review): lds_padding is false below, yet the expected detailed
// description prints "LDS data layout padding (to prevent bank
// conflicts): 8" - confirm which field that line is derived from.
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
// All four access orders use the identity order {0, 1, 2}.
.block_transfer_access_order_a = {.order = {0, 1, 2}},
|
||||
.block_transfer_access_order_b = {.order = {0, 1, 2}},
|
||||
.src_access_order_a = {.order = {0, 1, 2}},
|
||||
.src_access_order_b = {.order = {0, 1, 2}}};
|
||||
|
||||
// Convolution and GEMM specializations, both left at their default value.
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
// Block GEMM pipeline version and scheduler.
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
};
|
||||
// Compile-time check that DefaultAlgorithm satisfies ckb::ConvAlgorithmDescriptor.
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
// The brief description of the default signature/algorithm pair is a single
// line naming the dimensionality and direction.
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
    // Static constexpr objects so they can be used as template arguments.
    static constexpr ConvSignature DEFAULT_SIGNATURE{};
    static constexpr DefaultAlgorithm DEFAULT_ALGORITHM{};
    using DescribedBuilder = ckb::ConvBuilder<DEFAULT_SIGNATURE, DEFAULT_ALGORITHM>;

    EXPECT_THAT(ckr::Describe<DescribedBuilder>().brief(),
                ckt::StringEqWithDiff("2D Forward convolution"));
}
|
||||
|
||||
// The detailed description of the default signature/algorithm pair is a tree
// view that must match the expected text below character for character.
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
{
    // Static constexpr objects so they can be used as template arguments.
    static constexpr ConvSignature DEFAULT_SIGNATURE{};
    static constexpr DefaultAlgorithm DEFAULT_ALGORITHM{};
    using DescribedBuilder = ckb::ConvBuilder<DEFAULT_SIGNATURE, DEFAULT_ALGORITHM>;

    // NOTE(review): the expected text carries a trailing separator in
    // "Tile dimensions: 4×256×8×" and two label-less leaves at the end
    // ("│ └─ " and "└─ "). These look like artifacts of the current output
    // format - confirm they are intended.
    static constexpr const char* EXPECTED_TREE = //
        "2D Forward Convolution Kernel\n"
        "├─ Signature\n"
        "│ ├─ Tensor Type: FP16\n"
        "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
        "│ ├─ Input elementwise operation: PASS_THROUGH\n"
        "│ ├─ Weights elementwise operation: PASS_THROUGH\n"
        "│ └─ Output elementwise operation: PASS_THROUGH\n"
        "├─ Algorithm\n"
        "│ ├─ Thread block size: 256\n"
        "│ ├─ Data tile size: 256×256×32\n"
        "│ ├─ Gemm padding: DEFAULT\n"
        "│ ├─ Convolution specialization: DEFAULT\n"
        "│ ├─ Pipeline version: V4\n"
        "│ ├─ Pipeline scheduler: INTRAWAVE\n"
        "│ ├─ Warp Gemm parameters: \n"
        "│ │ ├─ subtile size: 16×16\n"
        "│ │ └─ Number of warp gemm iterations: 4×4\n"
        "│ ├─ Memory access:\n"
        "│ │ ├─ A Tile transfer: \n"
        "│ │ │ ├─ Tile dimensions: 4×256×8×\n"
        "│ │ │ ├─ The innermost K subdimension size: 8\n"
        "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
        "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
        "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
        "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
        "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
        "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
        "│ │ ├─ B Tile transfer: \n"
        "│ │ │ ├─ Tile dimensions: 4×256×8×\n"
        "│ │ │ ├─ The innermost K subdimension size: 8\n"
        "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
        "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
        "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
        "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
        "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
        "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
        "│ │ └─ C Tile transfer: \n"
        "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
        "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
        "│ │ └─ Vector access (GMEM write) instruction size: 8\n"
        "│ └─ \n"
        "└─ ";

    EXPECT_THAT(ckr::Describe<DescribedBuilder>().detailed(),
                ckt::StringEqWithDiff(EXPECTED_TREE));
}
|
||||
|
||||
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
|
||||
// does not have a specialization for backward data convolutions. The test fails with:
|
||||
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
|
||||
//
|
||||
// To enable this test, a ConvFactory specialization for backward data operations must be
|
||||
// implemented first.
|
||||
//
|
||||
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
|
||||
// {
|
||||
// struct BackwardDataSignature
|
||||
// {
|
||||
// int spatial_dim = 2;
|
||||
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
|
||||
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
// ckb::DataType data_type = ckb::DataType::FP16;
|
||||
// ckb::ElementwiseOperation elementwise_operation =
// ckb::ElementwiseOperation::PASS_THROUGH;
// ckb::GroupConvDeviceOp device_operation =
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
// };
|
||||
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
|
||||
//
|
||||
// static constexpr const BackwardDataSignature SIGNATURE;
|
||||
// static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
//
|
||||
// // Verify Brief works
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
|
||||
// ckt::StringEqWithDiff("2D Backward Data convolution"));
|
||||
//
|
||||
// // Verify detailed works - to be updated once ConvFactory is implemented
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
|
||||
// ckt::StringEqWithDiff("PLACEHOLDER"));
|
||||
// }
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user