diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 16506e9681..c900bb8c40 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template concept BlockGemmDescriptor = requires(T t) { - { t.pipeline_version } -> std::convertible_to; - { t.scheduler } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; }; // Concept for parameters that describe a gridwise WMMA GEMM problem. @@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> std::convertible_to; { t.m_wmma_per_wave } -> std::convertible_to; { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; }; // Concept for vectorized data transfer for convolution input tensors. @@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. 
template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires { template concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; + { T::loop_scheduler } -> std::convertible_to; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index c48228fa37..87d2701e01 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -299,23 +299,22 @@ consteval BlockGemmSpec SetBlockGemm() switch(BG.scheduler) { - case BlockGemmPipelineScheduler::INTRAWAVE: - scheduler = ck::BlockGemmPipelineScheduler::Intrawave; - break; - case BlockGemmPipelineScheduler::INTERWAVE: - scheduler = ck::BlockGemmPipelineScheduler::Interwave; - break; - default: throw "Unknown BlockGemmPipelineScheduler"; + case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break; + case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break; + case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave."; + default: throw "Unknown PipelineScheduler"; } switch(BG.pipeline_version) { - case BlockGemmPipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; - case BlockGemmPipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; - case BlockGemmPipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; - case BlockGemmPipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; - case BlockGemmPipelineVersion::V5: 
version = ck::BlockGemmPipelineVersion::v5; break; - default: throw "Unknown BlockGemmPipelineVersion"; + case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; + case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; + case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; + case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; + case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; + default: throw "Unknown PipelineVersion"; } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -427,9 +426,10 @@ consteval ck::LoopScheduler SetLoopScheduler() using ck_loop_sched = ck::LoopScheduler; switch(loop_scheduler) { - case LoopScheduler::DEFAULT: return ck_loop_sched::Default; - case LoopScheduler::INTERWAVE: return ck_loop_sched::Interwave; - default: throw "Unknown LoopScheduler"; + case PipelineScheduler::DEFAULT: return ck_loop_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave; + case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE."; + default: throw "Unknown PipelineScheduler"; } } @@ -440,12 +440,12 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { - case GridwiseGemmPipelineVersion::V1: return ck_pipeline::v1; - case GridwiseGemmPipelineVersion::V2: return ck_pipeline::v2; - case GridwiseGemmPipelineVersion::V4: return ck_pipeline::v4; - case GridwiseGemmPipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; - case GridwiseGemmPipelineVersion::V3: - throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K."; + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: throw 
"PipelineVersion::V3 is used only for stream-K."; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; default: throw "Unknown GridwiseGemmPipelineVersion"; } } @@ -482,15 +482,15 @@ template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - using ck_block_gemm = ck::BlockGemmPipelineVersion; + using ck_pipeline = ck::BlockGemmPipelineVersion; switch(version) { - case BlockGemmPipelineVersion::V1: return ck_block_gemm::v1; - case BlockGemmPipelineVersion::V2: return ck_block_gemm::v2; - case BlockGemmPipelineVersion::V3: return ck_block_gemm::v3; - case BlockGemmPipelineVersion::V4: return ck_block_gemm::v4; - case BlockGemmPipelineVersion::V5: return ck_block_gemm::v5; - default: throw "Unknown BlockGemmPipelineVersion"; + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: return ck_pipeline::v3; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: return ck_pipeline::v5; + default: throw "Unknown block GEMM PipelineVersion"; } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp new file mode 100644 index 0000000000..a74d77d155 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::conv { + +// Helper metafunctions to convert from ck enums to builder enums + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v3) + return V3; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == v5) + return V5; +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. 
+template +constexpr auto convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == weight_only) + return WEIGHT_ONLY; +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates +/// across multiple wavefronts. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Intrawave) + return INTRAWAVE; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. DEFAULT uses the standard +/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved +/// performance in certain scenarios. 
+template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Default) + return DEFAULT; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Helper structures for organizing trait data with domain-specific naming + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. +struct InputTileTransferDimensions +{ + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. +}; + +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. +struct InputTileTransferParams +{ + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. 
+ std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. +}; + +/// @brief Complete information for an input tile transfer. +/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. +}; + +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. +struct WarpGemmParams +{ + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). +}; + +/// @brief Parameters for shuffling data between warps (CShuffle optimization). 
+/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. +}; + +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. +}; + +// Helper metafunctions to derive signature information from Instance types + +/// @brief Derives the convolution direction from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). 
+template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + { + return builder::ConvDirection::FORWARD; + } + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + { + return builder::ConvDirection::BACKWARD_DATA; + } + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + { + return builder::ConvDirection::BACKWARD_WEIGHT; + } + else + { + return builder::ConvDirection::FORWARD; // Default fallback + } +} + +/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or +/// `builder::ConvBwdWeightSpecialization` enum value. +template +constexpr auto conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + + if constexpr(InstTraits::kConvForwardSpecialization == Default) + { + return builder::ConvFwdSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) + { + return builder::ConvFwdSpecialization::FILTER_3x3; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + + if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + { + return builder::ConvBwdDataSpecialization::DEFAULT; + } + else if 
constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + + if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + { + return builder::ConvBwdWeightSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) + { + return builder::ConvBwdWeightSpecialization::ODD_C; + } + } +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. 
+template +constexpr auto conv_layout() +{ + using InstTraits = InstanceTraits; + using ALayout = typename InstTraits::ALayout; + using BLayout = typename InstTraits::BLayout; + using ELayout = typename InstTraits::ELayout; + + namespace ctc = ck::tensor_layout::convolution; + + if constexpr(InstTraits::kSpatialDim == 1) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + } + } + else if constexpr(InstTraits::kSpatialDim == 2) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + } + } + else if constexpr(InstTraits::kSpatialDim == 3) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + } + else if constexpr(std::is_same_v && + 
std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + } + } +} + +/// @brief Derives the data type from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + + if constexpr(std::is_same_v) + { + return builder::DataType::FP16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::BF16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP32; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::U8; + } + else + { + // Default fallback + return builder::DataType::FP32; + } +} + +/// @brief Derives the elementwise operation from op type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. 
+template +constexpr builder::ElementwiseOperation elementwise_op() +{ + constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "Bias")) + { + return builder::ElementwiseOperation::BIAS; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) + { + return builder::ElementwiseOperation::BIAS_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + { + return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + { + return builder::ElementwiseOperation::BILINEAR; + } + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + { + return builder::ElementwiseOperation::CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + { + return builder::ElementwiseOperation::SCALE; + } + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + { + return builder::ElementwiseOperation::PASS_THROUGH; + } +} + +/// @brief Derives a gemm padding from a kernel instance type. +/// @tparam Instance - A Device Kernel object type. +/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. 
+template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto gemm_spec = InstTraits::kGemmSpecialization; + + if constexpr(gemm_spec == Default) + { + return DEFAULT; + } + else if constexpr(gemm_spec == MPadding) + { + return M_PADDING; + } + else if constexpr(gemm_spec == NPadding) + { + return N_PADDING; + } + else if constexpr(gemm_spec == KPadding) + { + return K_PADDING; + } + else if constexpr(gemm_spec == MNPadding) + { + return MN_PADDING; + } + else if constexpr(gemm_spec == MKPadding) + { + return MK_PADDING; + } + else if constexpr(gemm_spec == NKPadding) + { + return NK_PADDING; + } + else if constexpr(gemm_spec == MNKPadding) + { + return MNK_PADDING; + } + else if constexpr(gemm_spec == OPadding) + { + return O_PADDING; + } + else if constexpr(gemm_spec == MOPadding) + { + return MO_PADDING; + } + else if constexpr(gemm_spec == NOPadding) + { + return NO_PADDING; + } + else if constexpr(gemm_spec == KOPadding) + { + return KO_PADDING; + } + else if constexpr(gemm_spec == MNOPadding) + { + return MNO_PADDING; + } + else if constexpr(gemm_spec == MKOPadding) + { + return MKO_PADDING; + } + else if constexpr(gemm_spec == NKOPadding) + { + return NKO_PADDING; + } + else if constexpr(gemm_spec == MNKOPadding) + { + return MNKO_PADDING; + } +} + +/// @brief Primary template for extracting convolution traits. +/// @details This struct is the main entry point for reflecting on a convolution +/// kernel's properties. It is specialized to handle different kinds of input types. +template +struct ConvTraits; + +/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. +/// @details This is the primary specialization used to extract a comprehensive +/// set of traits directly from a fully-formed device kernel `Instance` type. 
+/// It uses `InstanceTraits` to access the kernel's template parameters. +template + requires requires { typename InstanceTraits; } +struct ConvTraits +{ + using InstTraits = InstanceTraits; + + // --- Signature Information --- + /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). + static constexpr int spatial_dim = InstTraits::kSpatialDim; + /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). + static constexpr builder::ConvDirection direction = conv_direction(); + /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). + static constexpr auto layout = conv_layout(); + /// @brief The primary data type used in the computation (e.g., FP16, FP32). + static constexpr builder::DataType data_type = conv_data_type(); + + static constexpr builder::ElementwiseOperation input_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation weight_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation output_element_op = + elementwise_op(); + + /// @brief The GEMM specialization used by the kernel - padding + static constexpr auto gemm_padding = gemm_spec(); + /// @brief The convolution-specific specialization (e.g., Default, 1x1). + static constexpr auto conv_specialization = conv_spec(); + + // --- Algorithm Information --- + /// @brief The total number of threads in a thread block (workgroup). + static constexpr int thread_block_size = InstTraits::kBlockSize; + /// @brief The dimensions of the data tile processed by the thread block. + static constexpr DataTileInfo tile_dims = { + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + + /// @brief Configuration for the A-matrix (input) tile transfer. 
+ static constexpr InputTileTransferInfo a_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + + /// @brief Configuration for the B-matrix (weights) tile transfer. + static constexpr InputTileTransferInfo b_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + + /// @brief Parameters for the warp-level GEMM computation. + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; + + /// @brief Configuration for the C-matrix (output) tile transfer. 
+ static constexpr OutputTileTransferInfo c_tile_transfer = { + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + + /// @brief Helper to safely get the pipeline version. + /// @details This is only available for some convolutions (e.g., forward). + /// If not present in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_version() + { + if constexpr(requires { T::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + // Return a default or indicate not available + return builder::PipelineVersion::V1; + } + } + + /// @brief The block GEMM pipeline version used by the kernel. + static constexpr auto pipeline_version = get_pipeline_version(); + + /// @brief Helper to safely get the pipeline scheduler. + /// @details This is only available for some convolutions. If not present + /// in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_scheduler() + { + if constexpr(requires { T::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { T::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + // Return a default or indicate not available + return builder::PipelineScheduler::DEFAULT; + } + } + + /// @brief The pipeline scheduler used by the kernel. + static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); +}; + +/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. +/// @details This specialization provides backward compatibility for reflecting +/// on kernels defined via the `ConvBuilder` interface. 
It works by first +/// creating the `Instance` via the builder's factory, and then delegating +/// all trait extraction to the `ConvTraits` specialization. +template +struct ConvTraits> +{ + using Factory = builder::ConvFactory; + using Instance = typename Factory::Instance; + + // Delegate to Instance-based ConvTraits + using InstanceConvTraits = ConvTraits; + + // Forward all members from Instance-based traits + static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; + static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; + static constexpr auto layout = InstanceConvTraits::layout; + static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + + static constexpr builder::ElementwiseOperation input_element_op = + InstanceConvTraits::input_element_op; + static constexpr builder::ElementwiseOperation weight_element_op = + InstanceConvTraits::weight_element_op; + static constexpr builder::ElementwiseOperation output_element_op = + InstanceConvTraits::output_element_op; + + static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; + static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; + + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; + static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; + static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; + static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; + static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; + static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; + static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; + static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; +}; + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index c9b45691cc..07f1b94b07 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -14,18 +14,9 @@ #pragma once -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include "instance_traits_util.hpp" +#include namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 7eb4c883b0..d6fc6da0d6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. 
// This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 60f991b1fc..9edfa4d4c9 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 95d1c94de4..c863d2306c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -9,9 +9,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name() } } +/// @brief Makes a case insensitive comparison of two string views. 
+/// @param a First string view +/// @param b Second string view +/// @return Whether the two string views are equal, ignoring case +constexpr bool case_insensitive_equal(std::string_view a, std::string_view b) +{ + if(a.size() != b.size()) + return false; + + for(size_t i = 0; i < a.size(); ++i) + { + char c1 = a[i]; + char c2 = b[i]; + + // Convert to lowercase for comparison + if(c1 >= 'A' && c1 <= 'Z') + c1 += 32; + if(c2 >= 'A' && c2 <= 'Z') + c2 += 32; + + if(c1 != c2) + return false; + } + return true; +} + } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index f09d740d20..2738452186 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -82,29 +82,14 @@ enum class ElementwiseOperation PASS_THROUGH }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion +// Enums for pipeline versions & schedulers +enum class PipelineVersion { V1, V2, V3, V4, - V5 -}; - -enum struct BlockGemmPipelineScheduler -{ - INTRAWAVE, - INTERWAVE, -}; - -// Enums for the gridwise GEMM pipeline versions. -enum class GridwiseGemmPipelineVersion -{ - V1, - V2, - V3, // Only used in stream-K implementation - V4, + V5, WEIGHT_ONLY }; @@ -140,9 +125,47 @@ enum class ConvFwdSpecialization FILTER_3x3 }; -enum class LoopScheduler +// Enums for the backward data convolution specialization. +enum class ConvBwdDataSpecialization { DEFAULT, + FILTER_1X1_STRIDE1_PAD0, +}; + +// Enums for the backward weight convolution specialization. +enum class ConvBwdWeightSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, + FILTER_1X1_PAD0, + ODD_C, +}; + +// Enums for the Gemm padding.
+enum class GemmPadding +{ + DEFAULT, + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING, + O_PADDING, + MO_PADDING, + NO_PADDING, + KO_PADDING, + MNO_PADDING, + MKO_PADDING, + NKO_PADDING, + MNKO_PADDING, +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, INTERWAVE }; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 90a5de6d66..15e1428419 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -65,6 +65,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) +add_ck_builder_test(test_conv_traits + conv/test_conv_traits.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp index 330db8d457..e2f66b10ad 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp @@ -27,7 +27,7 @@ TEST(FwdConvInstances, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, .num_groups_to_merge = 2, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; run_test( diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp index 1ec5bbb349..9573e14264 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, .fwd_specialization = ConvFwdSpecialization::DEFAULT, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; run_test( diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp new file mode 100644 index 0000000000..ca453d2ad4 --- /dev/null +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; + +// Test fixture for ConvTraits tests +class ConvTraitsTest : public ::testing::Test +{ +}; + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + 
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + 
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + 
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); + EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); + EXPECT_EQ(Traits::warp_gemm.m_iter, 4); + EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + 
ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, 
ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // 
ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + 
EXPECT_EQ(Traits::tile_dims.k, 16); +} +} // anonymous namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index e719db89ed..26adf46706 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,14 +49,14 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - GridwiseGemmPipelineVersion pipeline_version; + PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemm { - BlockGemmPipelineVersion pipeline_version; - BlockGemmPipelineScheduler scheduler; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; }; static_assert(ckb::BlockGemmDescriptor); @@ -136,7 +136,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; size_t num_groups_to_merge; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -147,7 +147,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 16c667b64b..017af87ab6 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -93,13 +93,12 @@ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; -constexpr 
GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = - GridwiseGemmPipelineVersion::V1}; +constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1, + .pipeline_version = PipelineVersion::V1}; constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -113,24 +112,19 @@ constexpr ThreadBlock FwdThreadBlock_64x32x32{.block_size = 64, constexpr ThreadBlock FwdThreadBlock_64x64x64{.block_size = 128, .tile_size = {.m = 64, .n = 64, .k = 64}}; -constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V1, - .scheduler = - BlockGemmPipelineScheduler::INTRAWAVE}; +constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V2, - .scheduler = - BlockGemmPipelineScheduler::INTRAWAVE}; +constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V3, - .scheduler = - BlockGemmPipelineScheduler::INTRAWAVE}; +constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V4, - .scheduler = - BlockGemmPipelineScheduler::INTRAWAVE}; +constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; -constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V5, - 
.scheduler = - BlockGemmPipelineScheduler::INTRAWAVE}; +constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 01bb806789..219206c5ce 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace ck { namespace tensor_operation { namespace device { diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp old mode 100644 new mode 100755 index 8620e7337c..5bf8548470 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -136,66 +136,103 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) #endif } -// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html +struct WaitcntLayoutGfx12 +{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0] + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds + CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutGfx11 +{ // vm[15:10] (6), lgkm[9:4] (6), exp unused + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; + CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); } + 
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutLegacy +{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2 + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8] + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4] + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) + { + c &= VM_MASK; + return ((c & 0xF) << 0) | ((c & 0x30) << 10); + } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); } +}; + +// Select active layout +#if defined(__gfx12__) +using Waitcnt = WaitcntLayoutGfx12; +#elif defined(__gfx11__) +using Waitcnt = WaitcntLayoutGfx11; +#else +using Waitcnt = WaitcntLayoutLegacy; +#endif + +//---------------------------------------------- +// Public API: only from_* (constexpr templates) +//---------------------------------------------- struct waitcnt_arg { -#if defined(__gfx12__) - // use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8] - CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111; - - template - CK_TILE_DEVICE static constexpr index_t from_vmcnt() - { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & (cnt << 8); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() - { - return 0; // no export in MI series - } - - template - CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() - { - 
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & cnt; - } + // kMax* exposed for callers; match field widths per-arch +#if defined(__gfx12__) || defined(__gfx11__) + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none #else - // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 - // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV - CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111; + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split) + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits +#endif template CK_TILE_DEVICE static constexpr index_t from_vmcnt() { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10)); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() - { - static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]"); - return MAX & (cnt << 4); + static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range"); + return Waitcnt::pack_vm(cnt); } template CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() { - static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); - return MAX & (cnt << 8); + static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range"); + return Waitcnt::pack_lgkm(cnt); } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + if constexpr(Waitcnt::HAS_EXP) + { + // EXP_MASK only exists on legacy +#if !defined(__gfx12__) && !defined(__gfx11__) + 
static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range"); + return Waitcnt::pack_exp(cnt); +#else + (void)cnt; + return 0; #endif + } + else + { + static_assert(cnt == 0, "expcnt unsupported on this arch"); + return 0; + } + } }; template > {} -#define MAKE_SC_WITH(start_, step_) \ - ck_tile::static_counter, start_, step_> {} -#define NEXT_SC(c_) c_.next<__COUNTER__>() -#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>() + __extension__ ck_tile::static_counter> {} +#define MAKE_SC_WITH(start_, step_) \ + __extension__ ck_tile:: \ + static_counter, start_, step_> \ + { \ + } +#define NEXT_SC(c_) __extension__ c_.next<__COUNTER__>() +#define NEXT_SCI(c_, static_i_) __extension__ c_.next<__COUNTER__ + static_i_>() // Usage: // constexpr auto c = MAKE_SC() diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 276b7b38dc..7e6d22d4ce 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -74,6 +74,6 @@ class ProfilerOperationRegistry final #define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y) #define PP_CONCAT_IMPL(x, y) x##y -#define REGISTER_PROFILER_OPERATION(name, description, operation) \ - static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ +#define REGISTER_PROFILER_OPERATION(name, description, operation) \ + __extension__ static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation)