diff --git a/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp b/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp index abeaebe4ff..8bb4db99c8 100644 --- a/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp +++ b/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp @@ -1,38 +1,37 @@ #include "ck_tile/builder/conv_builder.hpp" -#include "../utils/types.hpp" +#include "ck_tile/builder/conv_signature_types.hpp" int main() { - namespace ckb = ck_tile::builder; - namespace ckb_examples = ck_tile::builder::examples; + using namespace ck_tile::builder; - constexpr ckb_examples::ConvSignature FwdConvSignature + constexpr ConvSignature FwdConvSignature { .spatial_dim = 2, - .direction = ckb::ConvDirection::FORWARD, - .layout = ckb::GroupConvLayout::CHANNELS_LAST, - .data_type = ckb::DataType::BF16, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout::CHANNELS_LAST, + .data_type = DataType::BF16 }; - static_assert(ckb::ValidConvSignature); + static_assert(ValidConvSignature); // To get valid configuration parameters, refer to "device_grouped_conv_fwd_xdl_comp_instance.hpp". // This file contains the current instances of the kernel DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3. // Currently the build has this kernel hard-coded. // In the future, we may need builders per kernel type since they typically have slightly different parameters. 
- constexpr ckb::ThreadBlock FwdThreadBlock + constexpr ThreadBlock FwdThreadBlock { .block_size = 256, .submatrix = {.m = 256, .n = 256, .k = 32} // Tile sizes }; - constexpr ckb::ConvTuningParams FwdTuningParams + constexpr ConvTuningParams FwdTuningParams { .ak1 = 8, .bk1 = 8, .m_per_xdl=32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4 }; - constexpr ckb_examples::BlockTransfer FwdBlockTransfer + constexpr BlockTransfer FwdBlockTransfer { .thread_cluster_dims_a = {.k0 = 4, .m = 64, .k1 = 1}, .thread_cluster_dims_b = {.k0 = 4, .n = 64, .k1 = 1}, @@ -50,15 +49,15 @@ int main() .b_source_access_order = {1, 0, 2} }; - constexpr ckb_examples::ConvAlgorithm FwdConvAlgorithm + constexpr ConvAlgorithm FwdConvAlgorithm { .thread_block = FwdThreadBlock, .tuning_params = FwdTuningParams, .block_transfer = FwdBlockTransfer, - .pipeline_version = ckb::BlockGemmPipelineVersion::V4, + .pipeline_version = BlockGemmPipelineVersion::V4, }; - using Builder = ckb::ConvBuilder; + using Builder = ConvBuilder; const auto kernel_string = Builder::Instance::TypeString(); std::cout << "Generated kernel: " << kernel_string << std::endl; diff --git a/example/ck_builder/utils/types.hpp b/example/ck_builder/utils/types.hpp deleted file mode 100644 index b3db7c66d2..0000000000 --- a/example/ck_builder/utils/types.hpp +++ /dev/null @@ -1,53 +0,0 @@ -#pragma once - -#include "ck_tile/builder/conv_algorithm.hpp" -#include "ck_tile/builder/conv_signature.hpp" - -namespace ck_tile::builder::examples -{ - namespace ckb = ck_tile::builder; - - struct ConvSignature { - int spatial_dim; - ckb::ConvDirection direction; - ckb::GroupConvLayout layout; - ckb::DataType data_type; - }; - static_assert(ckb::ConvSignatureDescriptor); - - struct BlockTransfer - { - ckb::BlockATransferLengths thread_cluster_dims_a; - ckb::BlockBTransferLengths thread_cluster_dims_b; - ckb::BlockCTransferLengths thread_cluster_dims_c; - ckb::VectorTransferAB vector_transfer_a; - ckb::VectorTransferAB 
vector_transfer_b; - ckb::VectorTransferC vector_transfer_c; - ckb::ThreadClusterAccessOrder a_thread_cluster_access_order; - ckb::ThreadClusterAccessOrder b_thread_cluster_access_order; - ckb::SourceAccessOrder a_source_access_order; - ckb::SourceAccessOrder b_source_access_order; - }; - - struct ConvAlgorithm - { - ckb::ThreadBlock thread_block; - ckb::ConvTuningParams tuning_params; - BlockTransfer block_transfer; - ckb::BlockGemmPipelineVersion pipeline_version; - }; - static_assert(ckb::ConvAlgorithmDescriptor); - static_assert(ckb::SpecifiesThreadBlock); - static_assert(ckb::SpecifiesConvTuning); - static_assert(ckb::SpecifiesGemmPipelineVersion); - static_assert(ckb::SpecifiesBlockATransfer); - static_assert(ckb::SpecifiesBlockBTransfer); - static_assert(ckb::SpecifiesBlockCTransfer); - static_assert(ckb::SpecifiesBlockAVectorTransfer); - static_assert(ckb::SpecifiesBlockBVectorTransfer); - static_assert(ckb::SpecifiesBlockCVectorTransfer); - static_assert(ckb::SpecifiesAThreadClusterAccessOrder); - static_assert(ckb::SpecifiesBThreadClusterAccessOrder); - static_assert(ckb::SpecifiesASourceAccessOrder); - static_assert(ckb::SpecifiesBSourceAccessOrder); -} // namespace ck_tile::builder::examples diff --git a/experimental/builder/include/ck_tile/builder/builder_utils.hpp b/experimental/builder/include/ck_tile/builder/builder_utils.hpp index a726542829..a2e9d472d2 100644 --- a/experimental/builder/include/ck_tile/builder/builder_utils.hpp +++ b/experimental/builder/include/ck_tile/builder/builder_utils.hpp @@ -67,4 +67,41 @@ struct UnsupportedEnumValue { }; +// Helper functions to convert enums to strings +constexpr std::string_view ConvDirectionToString(ConvDirection dir) +{ + switch(dir) + { + case ConvDirection::FORWARD: return "Forward"; + case ConvDirection::BACKWARD_DATA: return "Backward Data"; + case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; + default: return "Unknown"; + } +} + +constexpr std::string_view DataTypeToString(DataType dt) 
+{ + switch(dt) + { + case DataType::FP16: return "FP16"; + case DataType::FP32: return "FP32"; + case DataType::FP64: return "FP64"; + case DataType::BF16: return "BF16"; + case DataType::S16: return "S16"; + case DataType::S8: return "S8"; + case DataType::S4: return "S4"; + default: return "Unknown"; + } +} + +constexpr std::string_view LayoutToString(GroupConvLayout layout) +{ + switch(layout) + { + case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)"; + case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)"; + default: return "Unknown"; + } +} + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 910bb16469..01e5cebe1a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -3,23 +3,10 @@ #include #include #include +#include "types.hpp" namespace ck_tile::builder { -// TODO: VP (Oct 3, 2025) - Separate the concepts and structs into separate files. -// Concepts the define interface and structs are PODs that implement the concepts. -// The interface is really just the concepts. Clients can define their own structs -// as long as they satisfy the concepts. - -// Convenience struct for a tuple of m, n, and k values. -template -struct MNK -{ - T m{}; - T n{}; - T k{}; -}; - // Concept for thread block dimensions for a GEMM problem. template concept ThreadBlockDescriptor = requires(T t) { @@ -29,16 +16,6 @@ concept ThreadBlockDescriptor = requires(T t) { { t.submatrix.k } -> std::convertible_to; }; -// Specifiy thread block dimensions for a GEMM. -struct ThreadBlock -{ - // Thread block size. - int block_size; - // Size of the submatrix problem in a thread block. - MNK submatrix; -}; -static_assert(ThreadBlockDescriptor); - // Concept to check if struct specifies thread block info. 
template concept SpecifiesThreadBlock = requires { @@ -56,19 +33,6 @@ concept ConvTuningDescriptor = requires(T t) { { t.n_xdl_per_wave } -> std::convertible_to; }; -// Describe some convolution tuning parameters. -struct ConvTuningParams -{ - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - int ak1 = 0; - int bk1 = 0; - int m_per_xdl = 0; - int n_per_xdl = 0; - int m_xdl_per_wave = 0; - int n_xdl_per_wave = 0; -}; -static_assert(ConvTuningDescriptor); - // Concept to check if a struct specifies convolution tuning info. template concept SpecifiesConvTuning = requires { @@ -83,15 +47,6 @@ concept BlockATransferDescriptor = requires(T t) { { t.k1 } -> std::convertible_to; }; -// Describe A block transfer thread cluster lengths. -struct BlockATransferLengths -{ - int k0; - int m; - int k1; -}; -static_assert(BlockATransferDescriptor); - // Concept for B block transfer thread cluster lengths. template concept BlockBTransferDescriptor = requires(T t) { @@ -100,43 +55,18 @@ concept BlockBTransferDescriptor = requires(T t) { { t.k1 } -> std::convertible_to; }; -// Describe B block transfer thread cluster lengths. -struct BlockBTransferLengths -{ - int k0; - int n; - int k1; -}; -static_assert(BlockBTransferDescriptor); - // Concept for the thread cluster access order template concept ThreadClusterAccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// Describe the thread cluster access order. -struct ThreadClusterAccessOrder -{ - // Order of the cluster dimensions. Must be a permutation of {0, 1, 2}. - std::array order; -}; -static_assert(ThreadClusterAccessOrderDescriptor); - // Concept to describe source access order template concept SourceAccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; -// Describe the source access order. -struct SourceAccessOrder -{ - // Order of the source dimensions. Must be a permutation of {0, 1, 2}. 
- std::array order; -}; -static_assert(SourceAccessOrderDescriptor); - // Concept for C block transfer thread cluster lengths. template concept BlockCTransferDescriptor = requires(T t) { @@ -146,16 +76,6 @@ concept BlockCTransferDescriptor = requires(T t) { { t.n_wave_per_xdl } -> std::convertible_to; }; -// Describe C block transfer thread cluster lengths. -struct BlockCTransferLengths -{ - int m_block; - int m_wave_per_xdl; - int n_block; - int n_wave_per_xdl; -}; -static_assert(BlockCTransferDescriptor); - // Concept for vector transfer details for A and B tensors template concept VectorTransferDescriptorAB = requires(T t) { @@ -165,15 +85,6 @@ concept VectorTransferDescriptorAB = requires(T t) { { t.add_extra } -> std::convertible_to; }; -struct VectorTransferAB -{ - size_t src_vector_dim; - size_t src_scaler_per_vector; - size_t dest_scaler_per_vector_k1; - bool add_extra; -}; -static_assert(VectorTransferDescriptorAB); - // Concept for the C tensor vectors transfer details. template concept VectorTransferDescriptorC = requires(T t) { @@ -182,14 +93,6 @@ concept VectorTransferDescriptorC = requires(T t) { { t.scaler_per_vector } -> std::convertible_to; }; -struct VectorTransferC -{ - size_t m_xdl_per_wave_per_shuffle; - size_t n_xdl_per_wave_per_shuffle; - size_t scaler_per_vector; -}; -static_assert(VectorTransferDescriptorC); - // Concept to check if a struct specifies A Block tranfer info. template concept SpecifiesBlockATransfer = requires(T t) { @@ -250,15 +153,6 @@ concept SpecifiesBSourceAccessOrder = requires(T t) { { T::block_transfer.b_source_access_order } -> SourceAccessOrderDescriptor; }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion -{ - V1, - V3, - V4, - V5 -}; - // Concept to check if struct specifies block_gemm_pipeline_version. 
template concept SpecifiesGemmPipelineVersion = requires { diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_types.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_types.hpp new file mode 100644 index 0000000000..c7d4c265a5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_types.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include "conv_algorithm_concepts.hpp" + +namespace ck_tile::builder { +// Convenience struct for a tuple of m, n, and k values. +template +struct MNK +{ + T m{}; + T n{}; + T k{}; +}; + +// Specify thread block dimensions for a GEMM. +struct ThreadBlock +{ + // Thread block size. + int block_size; + // Size of the submatrix problem in a thread block. + MNK submatrix; +}; +static_assert(ThreadBlockDescriptor); + +// Describe some convolution tuning parameters. +struct ConvTuningParams +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + int ak1 = 0; + int bk1 = 0; + int m_per_xdl = 0; + int n_per_xdl = 0; + int m_xdl_per_wave = 0; + int n_xdl_per_wave = 0; +}; +static_assert(ConvTuningDescriptor); + +// Describe A block transfer thread cluster lengths. +struct BlockATransferLengths +{ + int k0; + int m; + int k1; +}; +static_assert(BlockATransferDescriptor); + +// Describe B block transfer thread cluster lengths. +struct BlockBTransferLengths +{ + int k0; + int n; + int k1; +}; +static_assert(BlockBTransferDescriptor); + +// Describe the thread cluster access order. +struct ThreadClusterAccessOrder +{ + // Order of the cluster dimensions. Must be a permutation of {0, 1, 2}. + std::array order; +}; +static_assert(ThreadClusterAccessOrderDescriptor); + +// Describe the source access order. +struct SourceAccessOrder +{ + // Order of the source dimensions. Must be a permutation of {0, 1, 2}. + std::array order; +}; +static_assert(SourceAccessOrderDescriptor); + +// Describe C block transfer thread cluster lengths. 
+struct BlockCTransferLengths +{ + int m_block; + int m_wave_per_xdl; + int n_block; + int n_wave_per_xdl; +}; +static_assert(BlockCTransferDescriptor); + +struct VectorTransferAB +{ + size_t src_vector_dim; + size_t src_scaler_per_vector; + size_t dest_scaler_per_vector_k1; + bool add_extra; +}; +static_assert(VectorTransferDescriptorAB); + +struct VectorTransferC +{ + size_t m_xdl_per_wave_per_shuffle; + size_t n_xdl_per_wave_per_shuffle; + size_t scaler_per_vector; +}; +static_assert(VectorTransferDescriptorC); + +struct BlockTransfer +{ + BlockATransferLengths thread_cluster_dims_a; + BlockBTransferLengths thread_cluster_dims_b; + BlockCTransferLengths thread_cluster_dims_c; + VectorTransferAB vector_transfer_a; + VectorTransferAB vector_transfer_b; + VectorTransferC vector_transfer_c; + ThreadClusterAccessOrder a_thread_cluster_access_order; + ThreadClusterAccessOrder b_thread_cluster_access_order; + SourceAccessOrder a_source_access_order; + SourceAccessOrder b_source_access_order; +}; + +struct ConvAlgorithm +{ + ThreadBlock thread_block; + ConvTuningParams tuning_params; + BlockTransfer block_transfer; + BlockGemmPipelineVersion pipeline_version; +}; +static_assert(ConvAlgorithmDescriptor); +static_assert(SpecifiesThreadBlock); +static_assert(SpecifiesConvTuning); +static_assert(SpecifiesGemmPipelineVersion); +static_assert(SpecifiesBlockATransfer); +static_assert(SpecifiesBlockBTransfer); +static_assert(SpecifiesBlockCTransfer); +static_assert(SpecifiesBlockAVectorTransfer); +static_assert(SpecifiesBlockBVectorTransfer); +static_assert(SpecifiesBlockCVectorTransfer); +static_assert(SpecifiesAThreadClusterAccessOrder); +static_assert(SpecifiesBThreadClusterAccessOrder); +static_assert(SpecifiesASourceAccessOrder); +static_assert(SpecifiesBSourceAccessOrder); + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_builder.hpp b/experimental/builder/include/ck_tile/builder/conv_builder.hpp index 910192c341..c0ca8c4e1b 
100644 --- a/experimental/builder/include/ck_tile/builder/conv_builder.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_builder.hpp @@ -3,9 +3,7 @@ #include #include -#include #include -#include #include namespace ck_tile::builder { diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 264604ef5c..edc390baa7 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -33,12 +33,12 @@ #include #include #include -#include +#include #include #include #include -namespace ck_tile::builder { +namespace ck_tile::builder::factory_internal { // Type mappings from the builder GroupConvLayout enum class to the CK tensor data types. template @@ -374,6 +374,10 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() return ck::BlockGemmPipelineVersion::v4; } +} // namespace ck_tile::builder::factory_internal + +namespace ck_tile::builder { + // Primary template for the convolution factory. 
template { static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = ConvTensorLayouts; - using Types = ConvTensorTypes; - using Ops = ConvPassThroughOps; - static constexpr ConvSpec SPECIALIZATION{ + using Layouts = factory_internal::ConvTensorLayouts; + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ConvPassThroughOps; + static constexpr factory_internal::ConvSpec SPECIALIZATION{ .conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, }; - static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); - static constexpr ConvTuning TUNING = SetConvTuningInfo(); - static constexpr BlockTransfer A_BLOCK_TRANSFER = SetFwdConvABlockTransfer(); - static constexpr BlockTransfer B_BLOCK_TRANSFER = SetFwdConvBBlockTransfer(); - static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer(); - static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto TUNING = factory_internal::SetConvTuningInfo(); + static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); + static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion(); // The forward convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< // @@ -461,20 +465,20 @@ template { static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = ConvTensorLayouts; - using Types = ConvTensorTypes; - using Ops = ConvPassThroughOps; - static constexpr ConvSpec SPECIALIZATION{ + using Layouts = factory_internal::ConvTensorLayouts; + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ConvPassThroughOps; + static constexpr factory_internal::ConvSpec SPECIALIZATION{ .conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default, .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, }; - static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); - static constexpr ConvTuning TUNING = SetConvTuningInfo(); - static constexpr BlockTransfer A_BLOCK_TRANSFER = SetBwdDataConvABlockTransfer(); - static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBwdDataConvBBlockTransfer(); - static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer(); - static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto TUNING = factory_internal::SetConvTuningInfo(); + static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetBwdDataConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetBwdDataConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer(); + static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion(); // The backward-data convolution kernel class instance. 
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< diff --git a/experimental/builder/include/ck_tile/builder/conv_signature.hpp b/experimental/builder/include/ck_tile/builder/conv_signature.hpp index 6d09185577..6d2ba4b32a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature.hpp @@ -21,13 +21,6 @@ namespace ck_tile::builder { -// Memory layouts for convolution tensors, following PyTorch conventions. -enum class GroupConvLayout -{ - CHANNELS_LAST, // e.g., NHWGC - CHANNELS_FIRST // e.g., NGCHW -}; - // Constrains convolution to 1D, 2D, or 3D spatial dimensions. template concept ConvSpatialDim = std::is_integral_v && (N == 1 || N == 2 || N == 3); @@ -36,25 +29,6 @@ concept ConvSpatialDim = std::is_integral_v && (N == 1 || N == 2 || template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16); -// Direction of the convolution operation. -enum class ConvDirection -{ - FORWARD, - BACKWARD_DATA, - BACKWARD_WEIGHT -}; - -// Fused element-wise operations. -enum class ElementwiseOperation -{ - BIAS, - BIAS_CLAMP, - BILINEAR, - CLAMP, - SCALE, - PASS_THROUGH -}; - // Concept for a type that defines a convolution's operational signature. 
template concept ConvSignatureDescriptor = requires(T t) { @@ -83,41 +57,4 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_ template concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); -// Helper functions to convert enums to strings -constexpr std::string_view ConvDirectionToString(ConvDirection dir) -{ - switch(dir) - { - case ConvDirection::FORWARD: return "Forward"; - case ConvDirection::BACKWARD_DATA: return "Backward Data"; - case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; - default: return "Unknown"; - } -} - -constexpr std::string_view DataTypeToString(DataType dt) -{ - switch(dt) - { - case DataType::FP16: return "FP16"; - case DataType::FP32: return "FP32"; - case DataType::FP64: return "FP64"; - case DataType::BF16: return "BF16"; - case DataType::S16: return "S16"; - case DataType::S8: return "S8"; - case DataType::S4: return "S4"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout layout) -{ - switch(layout) - { - case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)"; - case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)"; - default: return "Unknown"; - } -} - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 6024bf74e7..b27e790189 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -13,4 +13,39 @@ enum class DataType S4, }; +// Memory layouts for convolution tensors, following PyTorch conventions. +enum class GroupConvLayout +{ + CHANNELS_LAST, // e.g., NHWGC + CHANNELS_FIRST // e.g., NGCHW +}; + +// Direction of the convolution operation. +enum class ConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + +// Fused element-wise operations. 
+enum class ElementwiseOperation +{ + BIAS, + BIAS_CLAMP, + BILINEAR, + CLAMP, + SCALE, + PASS_THROUGH +}; + +// Enums for the current block GEMM pipeline versions. +enum class BlockGemmPipelineVersion +{ + V1, + V3, + V4, + V5 +}; + } // namespace ck_tile::builder