Separate types from concepts.

This commit is contained in:
Ville Pietilä
2025-10-10 13:34:57 +00:00
parent 81ac06d29a
commit fe7ed96c2a
9 changed files with 249 additions and 263 deletions

View File

@@ -1,38 +1,37 @@
#include "ck_tile/builder/conv_builder.hpp"
#include "../utils/types.hpp"
#include "ck_tile/builder/conv_signature_types.hpp"
int main()
{
namespace ckb = ck_tile::builder;
namespace ckb_examples = ck_tile::builder::examples;
using namespace ck_tile::builder;
constexpr ckb_examples::ConvSignature FwdConvSignature
constexpr ConvSignature FwdConvSignature
{
.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
.layout = ckb::GroupConvLayout::CHANNELS_LAST,
.data_type = ckb::DataType::BF16,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout::CHANNELS_LAST,
.data_type = DataType::BF16
};
static_assert(ckb::ValidConvSignature<FwdConvSignature>);
static_assert(ValidConvSignature<FwdConvSignature>);
// To get valid configuration parameters, refer to "device_grouped_conv_fwd_xdl_comp_instance.hpp".
// This file contains the current instances of the kernel DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
// Currently the build has this kernel hard-coded.
// In the future, we may need builders per kernel type since they typically have slightly different parameters.
constexpr ckb::ThreadBlock FwdThreadBlock
constexpr ThreadBlock FwdThreadBlock
{
.block_size = 256,
.submatrix = {.m = 256, .n = 256, .k = 32} // Tile sizes
};
constexpr ckb::ConvTuningParams FwdTuningParams
constexpr ConvTuningParams FwdTuningParams
{
.ak1 = 8, .bk1 = 8, .m_per_xdl=32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4
};
constexpr ckb_examples::BlockTransfer FwdBlockTransfer
constexpr BlockTransfer FwdBlockTransfer
{
.thread_cluster_dims_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_dims_b = {.k0 = 4, .n = 64, .k1 = 1},
@@ -50,15 +49,15 @@ int main()
.b_source_access_order = {1, 0, 2}
};
constexpr ckb_examples::ConvAlgorithm FwdConvAlgorithm
constexpr ConvAlgorithm FwdConvAlgorithm
{
.thread_block = FwdThreadBlock,
.tuning_params = FwdTuningParams,
.block_transfer = FwdBlockTransfer,
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
.pipeline_version = BlockGemmPipelineVersion::V4,
};
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto kernel_string = Builder::Instance::TypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;

View File

@@ -1,53 +0,0 @@
#pragma once
#include "ck_tile/builder/conv_algorithm.hpp"
#include "ck_tile/builder/conv_signature.hpp"
namespace ck_tile::builder::examples
{
namespace ckb = ck_tile::builder;
struct ConvSignature {
int spatial_dim;
ckb::ConvDirection direction;
ckb::GroupConvLayout layout;
ckb::DataType data_type;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
struct BlockTransfer
{
ckb::BlockATransferLengths thread_cluster_dims_a;
ckb::BlockBTransferLengths thread_cluster_dims_b;
ckb::BlockCTransferLengths thread_cluster_dims_c;
ckb::VectorTransferAB vector_transfer_a;
ckb::VectorTransferAB vector_transfer_b;
ckb::VectorTransferC vector_transfer_c;
ckb::ThreadClusterAccessOrder a_thread_cluster_access_order;
ckb::ThreadClusterAccessOrder b_thread_cluster_access_order;
ckb::SourceAccessOrder a_source_access_order;
ckb::SourceAccessOrder b_source_access_order;
};
struct ConvAlgorithm
{
ckb::ThreadBlock thread_block;
ckb::ConvTuningParams tuning_params;
BlockTransfer block_transfer;
ckb::BlockGemmPipelineVersion pipeline_version;
};
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
static_assert(ckb::SpecifiesConvTuning<ConvAlgorithm>);
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockATransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockBTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockCTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockAVectorTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockBVectorTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockCVectorTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesAThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesBThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesASourceAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesBSourceAccessOrder<ConvAlgorithm>);
} // namespace ck_tile::builder::examples

View File

@@ -67,4 +67,41 @@ struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
// Maps a ConvDirection value to its display label. Any value without a
// dedicated label yields "Unknown".
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
    if(dir == ConvDirection::FORWARD)
        return "Forward";
    if(dir == ConvDirection::BACKWARD_DATA)
        return "Backward Data";
    if(dir == ConvDirection::BACKWARD_WEIGHT)
        return "Backward Weight";
    return "Unknown";
}
// Maps a DataType value to its display label. Any value without a dedicated
// label yields "Unknown".
constexpr std::string_view DataTypeToString(DataType dt)
{
    if(dt == DataType::FP16)
        return "FP16";
    if(dt == DataType::FP32)
        return "FP32";
    if(dt == DataType::FP64)
        return "FP64";
    if(dt == DataType::BF16)
        return "BF16";
    if(dt == DataType::S16)
        return "S16";
    if(dt == DataType::S8)
        return "S8";
    if(dt == DataType::S4)
        return "S4";
    return "Unknown";
}
// Maps a GroupConvLayout value to its display label. Any value without a
// dedicated label yields "Unknown".
// NOTE(review): the GroupConvLayout declaration cites NHWGC / NGCHW as example
// layouts, while the labels here omit the group dimension - confirm the
// intended wording.
constexpr std::string_view LayoutToString(GroupConvLayout layout)
{
    if(layout == GroupConvLayout::CHANNELS_FIRST)
        return "Channels-first (NCHW)";
    if(layout == GroupConvLayout::CHANNELS_LAST)
        return "Channels-last (NHWC)";
    return "Unknown";
}
} // namespace ck_tile::builder

View File

@@ -3,23 +3,10 @@
#include <type_traits>
#include <concepts>
#include <array>
#include "types.hpp"
namespace ck_tile::builder {
// TODO: VP (Oct 3, 2025) - Separate the concepts and structs into separate files.
// Concepts the define interface and structs are PODs that implement the concepts.
// The interface is really just the concepts. Clients can define their own structs
// as long as they satisfy the concepts.
// Convenience struct for a tuple of m, n, and k values.
template <typename T>
struct MNK
{
T m{};
T n{};
T k{};
};
// Concept for thread block dimensions for a GEMM problem.
template <typename T>
concept ThreadBlockDescriptor = requires(T t) {
@@ -29,16 +16,6 @@ concept ThreadBlockDescriptor = requires(T t) {
{ t.submatrix.k } -> std::convertible_to<int>;
};
// Specifiy thread block dimensions for a GEMM.
struct ThreadBlock
{
// Thread block size.
int block_size;
// Size of the submatrix problem in a thread block.
MNK<int> submatrix;
};
static_assert(ThreadBlockDescriptor<ThreadBlock>);
// Concept to check if struct specifies thread block info.
template <typename T>
concept SpecifiesThreadBlock = requires {
@@ -56,19 +33,6 @@ concept ConvTuningDescriptor = requires(T t) {
{ t.n_xdl_per_wave } -> std::convertible_to<int>;
};
// Describe some convolution tuning parameters.
struct ConvTuningParams
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
int ak1 = 0;
int bk1 = 0;
int m_per_xdl = 0;
int n_per_xdl = 0;
int m_xdl_per_wave = 0;
int n_xdl_per_wave = 0;
};
static_assert(ConvTuningDescriptor<ConvTuningParams>);
// Concept to check if a struct specifies convolution tuning info.
template <typename T>
concept SpecifiesConvTuning = requires {
@@ -83,15 +47,6 @@ concept BlockATransferDescriptor = requires(T t) {
{ t.k1 } -> std::convertible_to<int>;
};
// Describe A block transfer thread cluster lengths.
struct BlockATransferLengths
{
int k0;
int m;
int k1;
};
static_assert(BlockATransferDescriptor<BlockATransferLengths>);
// Concept for B block transfer thread cluster lengths.
template <typename T>
concept BlockBTransferDescriptor = requires(T t) {
@@ -100,43 +55,18 @@ concept BlockBTransferDescriptor = requires(T t) {
{ t.k1 } -> std::convertible_to<int>;
};
// Describe B block transfer thread cluster lengths.
struct BlockBTransferLengths
{
int k0;
int n;
int k1;
};
static_assert(BlockBTransferDescriptor<BlockBTransferLengths>);
// Concept for the thread cluster access order
template <typename T>
concept ThreadClusterAccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<int, 3>>;
};
// Describe the thread cluster access order.
struct ThreadClusterAccessOrder
{
// Order of the cluster dimensions. Must be a permutation of {0, 1, 2}.
std::array<int, 3> order;
};
static_assert(ThreadClusterAccessOrderDescriptor<ThreadClusterAccessOrder>);
// Concept to describe source access order
template <typename T>
concept SourceAccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<int, 3>>;
};
// Describe the source access order.
struct SourceAccessOrder
{
// Order of the source dimensions. Must be a permutation of {0, 1, 2}.
std::array<int, 3> order;
};
static_assert(SourceAccessOrderDescriptor<SourceAccessOrder>);
// Concept for C block transfer thread cluster lengths.
template <typename T>
concept BlockCTransferDescriptor = requires(T t) {
@@ -146,16 +76,6 @@ concept BlockCTransferDescriptor = requires(T t) {
{ t.n_wave_per_xdl } -> std::convertible_to<int>;
};
// Describe C block transfer thread cluster lengths.
struct BlockCTransferLengths
{
int m_block;
int m_wave_per_xdl;
int n_block;
int n_wave_per_xdl;
};
static_assert(BlockCTransferDescriptor<BlockCTransferLengths>);
// Concept for vector transfer details for A and B tensors
template <typename T>
concept VectorTransferDescriptorAB = requires(T t) {
@@ -165,15 +85,6 @@ concept VectorTransferDescriptorAB = requires(T t) {
{ t.add_extra } -> std::convertible_to<bool>;
};
struct VectorTransferAB
{
size_t src_vector_dim;
size_t src_scaler_per_vector;
size_t dest_scaler_per_vector_k1;
bool add_extra;
};
static_assert(VectorTransferDescriptorAB<VectorTransferAB>);
// Concept for the C tensor vectors transfer details.
template <typename T>
concept VectorTransferDescriptorC = requires(T t) {
@@ -182,14 +93,6 @@ concept VectorTransferDescriptorC = requires(T t) {
{ t.scaler_per_vector } -> std::convertible_to<size_t>;
};
struct VectorTransferC
{
size_t m_xdl_per_wave_per_shuffle;
size_t n_xdl_per_wave_per_shuffle;
size_t scaler_per_vector;
};
static_assert(VectorTransferDescriptorC<VectorTransferC>);
// Concept to check if a struct specifies A block transfer info.
template <typename T>
concept SpecifiesBlockATransfer = requires(T t) {
@@ -250,15 +153,6 @@ concept SpecifiesBSourceAccessOrder = requires(T t) {
{ T::block_transfer.b_source_access_order } -> SourceAccessOrderDescriptor;
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
{
V1,
V3,
V4,
V5
};
// Concept to check if struct specifies block_gemm_pipeline_version.
template <typename T>
concept SpecifiesGemmPipelineVersion = requires {

View File

@@ -0,0 +1,135 @@
#pragma once
#include "conv_algorithm_concepts.hpp"
namespace ck_tile::builder {
// Plain aggregate bundling three values of the same type under the names
// m, n, and k. Each member is value-initialized when no initializer is given.
template <class T>
struct MNK
{
    T m{}; // first component
    T n{}; // second component
    T k{}; // third component
};
// Specify thread block dimensions for a GEMM.
// Members carry default initializers so that a default-initialized object holds
// zeros instead of indeterminate values; aggregate and designated
// initialization are unaffected.
struct ThreadBlock
{
    // Thread block size.
    int block_size{0};
    // Size of the submatrix problem in a thread block.
    MNK<int> submatrix{};
};
static_assert(ThreadBlockDescriptor<ThreadBlock>);
// Convolution tuning parameters. Every field defaults to zero when it is not
// given an explicit value.
struct ConvTuningParams
{
    // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
    int ak1{0};
    int bk1{0};
    int m_per_xdl{0};
    int n_per_xdl{0};
    int m_xdl_per_wave{0};
    int n_xdl_per_wave{0};
};
static_assert(ConvTuningDescriptor<ConvTuningParams>);
// Describe A block transfer thread cluster lengths.
// Members are zero-initialized by default so a default-initialized object never
// holds indeterminate values; aggregate and designated initialization still
// work as before.
struct BlockATransferLengths
{
    int k0{0};
    int m{0};
    int k1{0};
};
static_assert(BlockATransferDescriptor<BlockATransferLengths>);
// Describe B block transfer thread cluster lengths.
// Members are zero-initialized by default so a default-initialized object never
// holds indeterminate values; aggregate and designated initialization still
// work as before.
struct BlockBTransferLengths
{
    int k0{0};
    int n{0};
    int k1{0};
};
static_assert(BlockBTransferDescriptor<BlockBTransferLengths>);
// Describe the thread cluster access order.
struct ThreadClusterAccessOrder
{
    // Order of the cluster dimensions. Must be a permutation of {0, 1, 2}.
    // Value-initialized to {0, 0, 0}, which is deliberately NOT a valid
    // permutation: an order that was never set is well-defined rather than
    // indeterminate.
    std::array<int, 3> order{};
};
static_assert(ThreadClusterAccessOrderDescriptor<ThreadClusterAccessOrder>);
// Describe the source access order.
struct SourceAccessOrder
{
    // Order of the source dimensions. Must be a permutation of {0, 1, 2}.
    // Value-initialized to {0, 0, 0}, which is deliberately NOT a valid
    // permutation: an order that was never set is well-defined rather than
    // indeterminate.
    std::array<int, 3> order{};
};
static_assert(SourceAccessOrderDescriptor<SourceAccessOrder>);
// Describe C block transfer thread cluster lengths.
// Members are zero-initialized by default so a default-initialized object never
// holds indeterminate values; aggregate and designated initialization still
// work as before.
struct BlockCTransferLengths
{
    int m_block{0};
    int m_wave_per_xdl{0};
    int n_block{0};
    int n_wave_per_xdl{0};
};
static_assert(BlockCTransferDescriptor<BlockCTransferLengths>);
// Vector transfer details for the A and B tensors.
// Members carry default initializers (zero / false) so a default-initialized
// object never holds indeterminate values; aggregate and designated
// initialization still work as before.
struct VectorTransferAB
{
    size_t src_vector_dim{0};
    size_t src_scaler_per_vector{0};
    size_t dest_scaler_per_vector_k1{0};
    bool add_extra{false};
};
static_assert(VectorTransferDescriptorAB<VectorTransferAB>);
// Vector transfer details for the C tensor.
// Members are zero-initialized by default so a default-initialized object never
// holds indeterminate values; aggregate and designated initialization still
// work as before.
struct VectorTransferC
{
    size_t m_xdl_per_wave_per_shuffle{0};
    size_t n_xdl_per_wave_per_shuffle{0};
    size_t scaler_per_vector{0};
};
static_assert(VectorTransferDescriptorC<VectorTransferC>);
// Bundles the block transfer settings of a convolution algorithm: thread
// cluster lengths for A, B, and C, vector transfer details for A, B, and C,
// and the thread cluster / source access orders for A and B.
// Every member is value-initialized by default so that a default-initialized
// BlockTransfer never contains indeterminate values; aggregate and designated
// initialization still work as before.
struct BlockTransfer
{
    BlockATransferLengths thread_cluster_dims_a{};
    BlockBTransferLengths thread_cluster_dims_b{};
    BlockCTransferLengths thread_cluster_dims_c{};
    VectorTransferAB vector_transfer_a{};
    VectorTransferAB vector_transfer_b{};
    VectorTransferC vector_transfer_c{};
    ThreadClusterAccessOrder a_thread_cluster_access_order{};
    ThreadClusterAccessOrder b_thread_cluster_access_order{};
    SourceAccessOrder a_source_access_order{};
    SourceAccessOrder b_source_access_order{};
};
// Complete description of a convolution algorithm: thread block dimensions,
// tuning parameters, block transfer settings, and the block GEMM pipeline
// version.
// Every member is value-initialized by default so that a default-initialized
// ConvAlgorithm never contains indeterminate values; aggregate and designated
// initialization still work as before.
struct ConvAlgorithm
{
    ThreadBlock thread_block{};
    ConvTuningParams tuning_params{};
    BlockTransfer block_transfer{};
    // Value-initialization selects the enumerator whose value is zero.
    BlockGemmPipelineVersion pipeline_version{};
};
// Compile-time checks that ConvAlgorithm satisfies the algorithm descriptor
// concept and each of the Specifies* concepts listed below. A failure here
// means the struct above no longer matches what those concepts require.
static_assert(ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(SpecifiesThreadBlock<ConvAlgorithm>);
static_assert(SpecifiesConvTuning<ConvAlgorithm>);
static_assert(SpecifiesGemmPipelineVersion<ConvAlgorithm>);
static_assert(SpecifiesBlockATransfer<ConvAlgorithm>);
static_assert(SpecifiesBlockBTransfer<ConvAlgorithm>);
static_assert(SpecifiesBlockCTransfer<ConvAlgorithm>);
static_assert(SpecifiesBlockAVectorTransfer<ConvAlgorithm>);
static_assert(SpecifiesBlockBVectorTransfer<ConvAlgorithm>);
static_assert(SpecifiesBlockCVectorTransfer<ConvAlgorithm>);
static_assert(SpecifiesAThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(SpecifiesBThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(SpecifiesASourceAccessOrder<ConvAlgorithm>);
static_assert(SpecifiesBSourceAccessOrder<ConvAlgorithm>);
} // namespace ck_tile::builder

View File

@@ -3,9 +3,7 @@
#include <concepts>
#include <type_traits>
#include <ck_tile/builder/conv_algorithm.hpp>
#include <ck_tile/builder/conv_factory.hpp>
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/versions.hpp>
namespace ck_tile::builder {

View File

@@ -33,12 +33,12 @@
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp>
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/conv_algorithm.hpp>
#include <ck_tile/builder/conv_algorithm_types.hpp>
#include <ck_tile/builder/builder_utils.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck_tile/builder/versions.hpp>
namespace ck_tile::builder {
namespace ck_tile::builder::factory_internal {
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
template <GroupConvLayout Layout, int SPATIAL_DIM, ConvDirection DIR>
@@ -374,6 +374,10 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
return ck::BlockGemmPipelineVersion::v4;
}
} // namespace ck_tile::builder::factory_internal
namespace ck_tile::builder {
// Primary template for the convolution factory.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
@@ -388,20 +392,20 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
static constexpr ConvSpec SPECIALIZATION{
using Layouts = factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ConvPassThroughOps;
static constexpr factory_internal::ConvSpec SPECIALIZATION{
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto TUNING = factory_internal::SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
// The forward convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
@@ -461,20 +465,20 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
static constexpr ConvSpec SPECIALIZATION{
using Layouts = factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ConvPassThroughOps;
static constexpr factory_internal::ConvSpec SPECIALIZATION{
.conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetBwdDataConvABlockTransfer<ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBwdDataConvBBlockTransfer<ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto TUNING = factory_internal::SetConvTuningInfo<SIGNATURE, ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetBwdDataConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetBwdDataConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
// The backward-data convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<

View File

@@ -21,13 +21,6 @@
namespace ck_tile::builder {
// Memory layouts for convolution tensors, following PyTorch conventions.
enum class GroupConvLayout
{
CHANNELS_LAST, // e.g., NHWGC
CHANNELS_FIRST // e.g., NGCHW
};
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
@@ -36,25 +29,6 @@ concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 ||
template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16);
// Direction of the convolution operation.
enum class ConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
// Fused element-wise operations.
enum class ElementwiseOperation
{
BIAS,
BIAS_CLAMP,
BILINEAR,
CLAMP,
SCALE,
PASS_THROUGH
};
// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
@@ -83,41 +57,4 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
// Helper functions to convert enums to strings
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
switch(dir)
{
case ConvDirection::FORWARD: return "Forward";
case ConvDirection::BACKWARD_DATA: return "Backward Data";
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
constexpr std::string_view DataTypeToString(DataType dt)
{
switch(dt)
{
case DataType::FP16: return "FP16";
case DataType::FP32: return "FP32";
case DataType::FP64: return "FP64";
case DataType::BF16: return "BF16";
case DataType::S16: return "S16";
case DataType::S8: return "S8";
case DataType::S4: return "S4";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout layout)
{
switch(layout)
{
case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)";
case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)";
default: return "Unknown";
}
}
} // namespace ck_tile::builder

View File

@@ -13,4 +13,39 @@ enum class DataType
S4,
};
// Memory layouts for convolution tensors, following PyTorch conventions.
// Enumerator values are spelled out; they match the implicit numbering.
enum class GroupConvLayout
{
    CHANNELS_LAST  = 0, // e.g., NHWGC
    CHANNELS_FIRST = 1, // e.g., NGCHW
};
// Direction of the convolution operation.
// Enumerator values are spelled out; they match the implicit numbering.
enum class ConvDirection
{
    FORWARD         = 0,
    BACKWARD_DATA   = 1,
    BACKWARD_WEIGHT = 2,
};
// Fused element-wise operations.
// Enumerator values are spelled out; they match the implicit numbering.
enum class ElementwiseOperation
{
    BIAS         = 0,
    BIAS_CLAMP   = 1,
    BILINEAR     = 2,
    CLAMP        = 3,
    SCALE        = 4,
    PASS_THROUGH = 5,
};
// Enums for the current block GEMM pipeline versions.
// The numeric values are positional (they match the implicit numbering) and
// are not the version numbers themselves.
enum class BlockGemmPipelineVersion
{
    V1 = 0,
    V3 = 1,
    V4 = 2,
    V5 = 3,
};
} // namespace ck_tile::builder