From b8527a92360496666ed6606e53ddc97e35dcf76e Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:53:06 +0100 Subject: [PATCH 1/9] [CK_BUILDER] Convolution traits. (#3152) Added: 1. Convolution traits & unit tests 2. Update builder enumerators to have representation of Convolution Kernels properties. 3. Unified builder pipeline version & scheduler enumerators --- .../builder/conv_algorithm_concepts.hpp | 12 +- .../include/ck_tile/builder/conv_factory.hpp | 48 +- .../ck_tile/builder/reflect/conv_traits.hpp | 719 ++++++++++++++++++ .../builder/reflect/instance_traits.hpp | 11 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 1 + ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 1 + .../builder/reflect/instance_traits_util.hpp | 28 + .../builder/include/ck_tile/builder/types.hpp | 61 +- experimental/builder/test/CMakeLists.txt | 3 + .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../builder/test/conv/test_conv_traits.cpp | 316 ++++++++ .../test/impl/conv_algorithm_types.hpp | 10 +- .../test/utils/ckb_conv_test_common.hpp | 18 +- ...olution_backward_weight_specialization.hpp | 2 + 20 files changed, 1165 insertions(+), 81 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp create mode 100644 experimental/builder/test/conv/test_conv_traits.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 365835684e..e43f910a73 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ 
b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template concept BlockGemmDescriptor = requires(T t) { - { t.pipeline_version } -> std::convertible_to; - { t.scheduler } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; }; // Concept for parameters that describe a gridwise WMMA GEMM problem. @@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> std::convertible_to; { t.m_wmma_per_wave } -> std::convertible_to; { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; }; // Concept for vectorized data transfer for convolution input tensors. @@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. 
template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires { template concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; + { T::loop_scheduler } -> std::convertible_to; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index a3932f524c..1ccc190ba2 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm() ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Intrawave; } - else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) + else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Interwave; } else { - static_assert(false, "Unknown BlockGemmPipelineScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } - if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + if constexpr(BG.pipeline_version == PipelineVersion::V1) { version = ck::BlockGemmPipelineVersion::v1; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) + else if constexpr(BG.pipeline_version == PipelineVersion::V2) { version = ck::BlockGemmPipelineVersion::v2; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) + else if constexpr(BG.pipeline_version == 
PipelineVersion::V3) { version = ck::BlockGemmPipelineVersion::v3; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) + else if constexpr(BG.pipeline_version == PipelineVersion::V4) { version = ck::BlockGemmPipelineVersion::v4; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) + else if constexpr(BG.pipeline_version == PipelineVersion::V5) { version = ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) { return ck::LoopScheduler::Default; } - else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) + else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) { return ck::LoopScheduler::Interwave; } else { - static_assert(false, "Unknown LoopScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } } @@ -460,29 +460,29 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + if constexpr(pipeline_version == PipelineVersion::V1) { return ck::PipelineVersion::v1; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) + else if constexpr(pipeline_version == PipelineVersion::V2) { return ck::PipelineVersion::v2; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) + else if constexpr(pipeline_version == PipelineVersion::V3) { static_assert(false, "V3 is used only for stream-K."); } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) + else if 
constexpr(pipeline_version == PipelineVersion::V4) { return ck::PipelineVersion::v4; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) + else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) { return ck::PipelineVersion::weight_only; } else { - static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } @@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - if constexpr(version == BlockGemmPipelineVersion::V1) + if constexpr(version == PipelineVersion::V1) { return ck::BlockGemmPipelineVersion::v1; } - else if constexpr(version == BlockGemmPipelineVersion::V2) + else if constexpr(version == PipelineVersion::V2) { return ck::BlockGemmPipelineVersion::v2; } - else if constexpr(version == BlockGemmPipelineVersion::V3) + else if constexpr(version == PipelineVersion::V3) { return ck::BlockGemmPipelineVersion::v3; } - else if constexpr(version == BlockGemmPipelineVersion::V4) + else if constexpr(version == PipelineVersion::V4) { return ck::BlockGemmPipelineVersion::v4; } - else if constexpr(version == BlockGemmPipelineVersion::V5) + else if constexpr(version == PipelineVersion::V5) { return ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp new file mode 100644 index 0000000000..a74d77d155 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::conv { + +// Helper metafunctions to convert from ck enums to builder enums + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v3) + return V3; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == v5) + return V5; +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. 
+template +constexpr auto convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == weight_only) + return WEIGHT_ONLY; +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates +/// across multiple wavefronts. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Intrawave) + return INTRAWAVE; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. DEFAULT uses the standard +/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved +/// performance in certain scenarios. 
+template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Default) + return DEFAULT; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Helper structures for organizing trait data with domain-specific naming + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. +struct InputTileTransferDimensions +{ + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. +}; + +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. +struct InputTileTransferParams +{ + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. 
+ std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. +}; + +/// @brief Complete information for an input tile transfer. +/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. +}; + +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. +struct WarpGemmParams +{ + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). +}; + +/// @brief Parameters for shuffling data between warps (CShuffle optimization). 
+/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. +}; + +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. +}; + +// Helper metafunctions to derive signature information from Instance types + +/// @brief Derives the convolution direction from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). 
+template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + { + return builder::ConvDirection::FORWARD; + } + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + { + return builder::ConvDirection::BACKWARD_DATA; + } + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + { + return builder::ConvDirection::BACKWARD_WEIGHT; + } + else + { + return builder::ConvDirection::FORWARD; // Default fallback + } +} + +/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or +/// `builder::ConvBwdWeightSpecialization` enum value. +template +constexpr auto conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + + if constexpr(InstTraits::kConvForwardSpecialization == Default) + { + return builder::ConvFwdSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) + { + return builder::ConvFwdSpecialization::FILTER_3x3; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + + if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + { + return builder::ConvBwdDataSpecialization::DEFAULT; + } + else if 
constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + + if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + { + return builder::ConvBwdWeightSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) + { + return builder::ConvBwdWeightSpecialization::ODD_C; + } + } +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. 
+template +constexpr auto conv_layout() +{ + using InstTraits = InstanceTraits; + using ALayout = typename InstTraits::ALayout; + using BLayout = typename InstTraits::BLayout; + using ELayout = typename InstTraits::ELayout; + + namespace ctc = ck::tensor_layout::convolution; + + if constexpr(InstTraits::kSpatialDim == 1) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + } + } + else if constexpr(InstTraits::kSpatialDim == 2) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + } + } + else if constexpr(InstTraits::kSpatialDim == 3) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + } + else if constexpr(std::is_same_v && + 
std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + } + } +} + +/// @brief Derives the data type from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + + if constexpr(std::is_same_v) + { + return builder::DataType::FP16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::BF16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP32; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::U8; + } + else + { + // Default fallback + return builder::DataType::FP32; + } +} + +/// @brief Derives the elementwise operation from op type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. 
+template +constexpr builder::ElementwiseOperation elementwise_op() +{ + constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "Bias")) + { + return builder::ElementwiseOperation::BIAS; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) + { + return builder::ElementwiseOperation::BIAS_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + { + return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + { + return builder::ElementwiseOperation::BILINEAR; + } + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + { + return builder::ElementwiseOperation::CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + { + return builder::ElementwiseOperation::SCALE; + } + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + { + return builder::ElementwiseOperation::PASS_THROUGH; + } +} + +/// @brief Derives a gemm padding from a kernel instance type. +/// @tparam Instance - A Device Kernel object type. +/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. 
+template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto gemm_spec = InstTraits::kGemmSpecialization; + + if constexpr(gemm_spec == Default) + { + return DEFAULT; + } + else if constexpr(gemm_spec == MPadding) + { + return M_PADDING; + } + else if constexpr(gemm_spec == NPadding) + { + return N_PADDING; + } + else if constexpr(gemm_spec == KPadding) + { + return K_PADDING; + } + else if constexpr(gemm_spec == MNPadding) + { + return MN_PADDING; + } + else if constexpr(gemm_spec == MKPadding) + { + return MK_PADDING; + } + else if constexpr(gemm_spec == NKPadding) + { + return NK_PADDING; + } + else if constexpr(gemm_spec == MNKPadding) + { + return MNK_PADDING; + } + else if constexpr(gemm_spec == OPadding) + { + return O_PADDING; + } + else if constexpr(gemm_spec == MOPadding) + { + return MO_PADDING; + } + else if constexpr(gemm_spec == NOPadding) + { + return NO_PADDING; + } + else if constexpr(gemm_spec == KOPadding) + { + return KO_PADDING; + } + else if constexpr(gemm_spec == MNOPadding) + { + return MNO_PADDING; + } + else if constexpr(gemm_spec == MKOPadding) + { + return MKO_PADDING; + } + else if constexpr(gemm_spec == NKOPadding) + { + return NKO_PADDING; + } + else if constexpr(gemm_spec == MNKOPadding) + { + return MNKO_PADDING; + } +} + +/// @brief Primary template for extracting convolution traits. +/// @details This struct is the main entry point for reflecting on a convolution +/// kernel's properties. It is specialized to handle different kinds of input types. +template +struct ConvTraits; + +/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. +/// @details This is the primary specialization used to extract a comprehensive +/// set of traits directly from a fully-formed device kernel `Instance` type. 
+/// It uses `InstanceTraits` to access the kernel's template parameters. +template + requires requires { typename InstanceTraits; } +struct ConvTraits +{ + using InstTraits = InstanceTraits; + + // --- Signature Information --- + /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). + static constexpr int spatial_dim = InstTraits::kSpatialDim; + /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). + static constexpr builder::ConvDirection direction = conv_direction(); + /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). + static constexpr auto layout = conv_layout(); + /// @brief The primary data type used in the computation (e.g., FP16, FP32). + static constexpr builder::DataType data_type = conv_data_type(); + + static constexpr builder::ElementwiseOperation input_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation weight_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation output_element_op = + elementwise_op(); + + /// @brief The GEMM specialization used by the kernel - padding + static constexpr auto gemm_padding = gemm_spec(); + /// @brief The convolution-specific specialization (e.g., Default, 1x1). + static constexpr auto conv_specialization = conv_spec(); + + // --- Algorithm Information --- + /// @brief The total number of threads in a thread block (workgroup). + static constexpr int thread_block_size = InstTraits::kBlockSize; + /// @brief The dimensions of the data tile processed by the thread block. + static constexpr DataTileInfo tile_dims = { + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + + /// @brief Configuration for the A-matrix (input) tile transfer. 
+ static constexpr InputTileTransferInfo a_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + + /// @brief Configuration for the B-matrix (weights) tile transfer. + static constexpr InputTileTransferInfo b_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + + /// @brief Parameters for the warp-level GEMM computation. + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; + + /// @brief Configuration for the C-matrix (output) tile transfer. 
+ static constexpr OutputTileTransferInfo c_tile_transfer = { + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + + /// @brief Helper to safely get the pipeline version. + /// @details This is only available for some convolutions (e.g., forward). + /// If not present in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_version() + { + if constexpr(requires { T::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + // Return a default or indicate not available + return builder::PipelineVersion::V1; + } + } + + /// @brief The block GEMM pipeline version used by the kernel. + static constexpr auto pipeline_version = get_pipeline_version(); + + /// @brief Helper to safely get the pipeline scheduler. + /// @details This is only available for some convolutions. If not present + /// in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_scheduler() + { + if constexpr(requires { T::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { T::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + // Return a default or indicate not available + return builder::PipelineScheduler::DEFAULT; + } + } + + /// @brief The pipeline scheduler used by the kernel. + static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); +}; + +/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. +/// @details This specialization provides backward compatibility for reflecting +/// on kernels defined via the `ConvBuilder` interface. 
It works by first +/// creating the `Instance` via the builder's factory, and then delegating +/// all trait extraction to the `ConvTraits` specialization. +template +struct ConvTraits> +{ + using Factory = builder::ConvFactory; + using Instance = typename Factory::Instance; + + // Delegate to Instance-based ConvTraits + using InstanceConvTraits = ConvTraits; + + // Forward all members from Instance-based traits + static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; + static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; + static constexpr auto layout = InstanceConvTraits::layout; + static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + + static constexpr builder::ElementwiseOperation input_element_op = + InstanceConvTraits::input_element_op; + static constexpr builder::ElementwiseOperation weight_element_op = + InstanceConvTraits::weight_element_op; + static constexpr builder::ElementwiseOperation output_element_op = + InstanceConvTraits::output_element_op; + + static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; + static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; + + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; + static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; + static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; + static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; + static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; + static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; + static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; + static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; +}; + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index c9b45691cc..07f1b94b07 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -14,18 +14,9 @@ #pragma once -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include "instance_traits_util.hpp" +#include namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 7eb4c883b0..d6fc6da0d6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. 
// This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 60f991b1fc..9edfa4d4c9 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 95d1c94de4..c863d2306c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -9,9 +9,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name() } } +/// @brief Makes a case insensitive comparison of two string views. 
+/// @param a First string view +/// @param b Second string view +/// @return Whether the two string views are equal, ignoring case +constexpr bool case_insensitive_equal(std::string_view a, std::string_view b) +{ + if(a.size() != b.size()) + return false; + + for(size_t i = 0; i < a.size(); ++i) + { + char c1 = a[i]; + char c2 = b[i]; + + // Convert to lowercase for comparison + if(c1 >= 'A' && c1 <= 'Z') + c1 += 32; + if(c2 >= 'A' && c2 <= 'Z') + c2 += 32; + + if(c1 != c2) + return false; + } + return true; +} + } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2650f0de16..2af10346e5 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -128,29 +128,14 @@ enum class ElementwiseOperation PASS_THROUGH }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion +// Enums for pipeline versions & schedulers +enum class PipelineVersion { V1, V2, V3, V4, - V5 -}; - -enum struct BlockGemmPipelineScheduler -{ - INTRAWAVE, - INTERWAVE, -}; - -// Enums for the gridwise GEMM pipeline versions. -enum class GridwiseGemmPipelineVersion -{ - V1, - V2, - V3, // Only used in stream-K implementation - V4, + V5, WEIGHT_ONLY }; @@ -186,9 +171,47 @@ enum class ConvFwdSpecialization FILTER_3x3 }; -enum class LoopScheduler +// Enums for the backward data convolution specialization. +enum class ConvBwdDataSpecialization { DEFAULT, + FILTER_1X1_STRIDE1_PAD0, +}; + +// Enums for the backward weight convolution specialization. +enum class ConvBwdWeightSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, + FILTER_1X1_PAD0, + ODD_C, +}; + +// Enums for the Gemm padding. 
+enum class GemmPadding +{ + DEFAULT, + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING, + O_PADDING, + MO_PADDING, + NO_PADDING, + KO_PADDING, + MNO_PADDING, + MKO_PADDING, + NKO_PADDING, + MNKO_PADDING, +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, INTERWAVE }; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8b5c4519a9..0cb3237f8c 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -64,6 +64,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) +add_ck_builder_test(test_conv_traits + conv/test_conv_traits.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index b58de836de..123034eb77 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -27,7 +27,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V2, + PipelineVersion::V2, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index b8dbf2ca97..240746f546 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ 
b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } @@ -47,7 +47,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index aba2f29ffd..6366016707 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V3, + PipelineVersion::V3, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 4d01323600..7b303a7bde 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index a30158aa8e..b40dd0b0d7 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index c0b2e613a2..e0dad4e1a1 100644 --- 
a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 0fea260eac..43ffb3f89a 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V1, + PipelineVersion::V1, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp new file mode 100644 index 0000000000..ca453d2ad4 --- /dev/null +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; + +// Test fixture for ConvTraits tests +class ConvTraitsTest : public ::testing::Test +{ +}; + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // 
BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 
8); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); + EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); + EXPECT_EQ(Traits::warp_gemm.m_iter, 4); + EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + 
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // 
BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + 
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + 
ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +} // anonymous namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 1a78028862..accc4048dc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,14 +49,14 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - GridwiseGemmPipelineVersion pipeline_version; + PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemm { - BlockGemmPipelineVersion pipeline_version; - BlockGemmPipelineScheduler scheduler; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; }; static_assert(ckb::BlockGemmDescriptor); @@ -156,7 +156,7 @@ 
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; size_t num_groups_to_merge; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); @@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index d145bdfd6c..7fd02a56f7 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -16,7 +16,7 @@ using namespace test; // Common test implementation template constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() { @@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() .src_access_order_b = {1, 0, 2}}; constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = PipelineScheduler::INTRAWAVE}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock, @@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); // Verify pipeline version is correct - if(FwdPipelineVersion == BlockGemmPipelineVersion::V1) + if(FwdPipelineVersion == PipelineVersion::V1) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3) + else if(FwdPipelineVersion == 
PipelineVersion::V3) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4) + else if(FwdPipelineVersion == PipelineVersion::V4) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5) + else if(FwdPipelineVersion == PipelineVersion::V5) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); // Verify specialization is correct @@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, .num_groups_to_merge = 2, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; @@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1, - .pipeline_version = GridwiseGemmPipelineVersion::V1}; + .pipeline_version = PipelineVersion::V1}; constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .fwd_specialization = FwdConvSpecialization, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 01bb806789..219206c5ce 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ 
b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace ck { namespace tensor_operation { namespace device { From 4533aa6dbab648adc1a496b6064cb79777c41cf5 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:42:22 -0800 Subject: [PATCH 2/9] Fix compilation errors with clang22. (#3164) * resolve compilation issue with clang22 * add __extension__ for __COUNTER__ usage in ck_tile --- include/ck_tile/core/utility/static_counter.hpp | 13 ++++++++----- profiler/src/profiler_operation_registry.hpp | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/core/utility/static_counter.hpp b/include/ck_tile/core/utility/static_counter.hpp index 84af3dd52f..4828e2e010 100644 --- a/include/ck_tile/core/utility/static_counter.hpp +++ b/include/ck_tile/core/utility/static_counter.hpp @@ -102,11 +102,14 @@ struct static_counter_uniq_; } #define MAKE_SC() \ - ck_tile::static_counter> {} -#define MAKE_SC_WITH(start_, step_) \ - ck_tile::static_counter, start_, step_> {} -#define NEXT_SC(c_) c_.next<__COUNTER__>() -#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>() + __extension__ ck_tile::static_counter> {} +#define MAKE_SC_WITH(start_, step_) \ + __extension__ ck_tile:: \ + static_counter, start_, step_> \ + { \ + } +#define NEXT_SC(c_) __extension__ c_.next<__COUNTER__>() +#define NEXT_SCI(c_, static_i_) __extension__ c_.next<__COUNTER__ + static_i_>() // Usage: // constexpr auto c = MAKE_SC() diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 276b7b38dc..7e6d22d4ce 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -74,6 +74,6 @@ class ProfilerOperationRegistry final #define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y) #define PP_CONCAT_IMPL(x, y) x##y -#define 
REGISTER_PROFILER_OPERATION(name, description, operation) \ - static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ +#define REGISTER_PROFILER_OPERATION(name, description, operation) \ + __extension__ static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) From 12922120d2567c3512048d7e8ed37e387a07bab6 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Thu, 6 Nov 2025 14:29:03 +0800 Subject: [PATCH 3/9] add gfx11's barrier following SPG's reference (#3159) * add gfx11's barrier following SPG's reference * re-format the code * minor fix --------- Co-authored-by: ThomasNing --- include/ck_tile/core/arch/arch.hpp | 129 +++++++++++++++++++---------- 1 file changed, 83 insertions(+), 46 deletions(-) mode change 100644 => 100755 include/ck_tile/core/arch/arch.hpp diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp old mode 100644 new mode 100755 index 8620e7337c..5bf8548470 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -136,66 +136,103 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) #endif } -// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html +struct WaitcntLayoutGfx12 +{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0] + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds + CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutGfx11 +{ // vm[15:10] (6), lgkm[9:4] (6), exp unused + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; + CK_TILE_DEVICE 
static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutLegacy +{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2 + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8] + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4] + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) + { + c &= VM_MASK; + return ((c & 0xF) << 0) | ((c & 0x30) << 10); + } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); } +}; + +// Select active layout +#if defined(__gfx12__) +using Waitcnt = WaitcntLayoutGfx12; +#elif defined(__gfx11__) +using Waitcnt = WaitcntLayoutGfx11; +#else +using Waitcnt = WaitcntLayoutLegacy; +#endif + +//---------------------------------------------- +// Public API: only from_* (constexpr templates) +//---------------------------------------------- struct waitcnt_arg { -#if defined(__gfx12__) - // use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8] - CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111; - - template - CK_TILE_DEVICE static constexpr index_t from_vmcnt() - { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & (cnt << 8); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() 
- { - return 0; // no export in MI series - } - - template - CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() - { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & cnt; - } + // kMax* exposed for callers; match field widths per-arch +#if defined(__gfx12__) || defined(__gfx11__) + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none #else - // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 - // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV - CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111; + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split) + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits +#endif template CK_TILE_DEVICE static constexpr index_t from_vmcnt() { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10)); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() - { - static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]"); - return MAX & (cnt << 4); + static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range"); + return Waitcnt::pack_vm(cnt); } template CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() { - static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); - return MAX & (cnt << 8); + static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range"); + return Waitcnt::pack_lgkm(cnt); } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + if 
constexpr(Waitcnt::HAS_EXP) + { + // EXP_MASK only exists on legacy +#if !defined(__gfx12__) && !defined(__gfx11__) + static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range"); + return Waitcnt::pack_exp(cnt); +#else + (void)cnt; + return 0; #endif + } + else + { + static_assert(cnt == 0, "expcnt unsupported on this arch"); + return 0; + } + } }; template Date: Thu, 6 Nov 2025 11:26:30 +0100 Subject: [PATCH 4/9] [CK TILE] Convolution remove magic values (#3160) * [CK TILE] Refactor Conv configs and Conv Elementwise * fix * [CK TILE] Convolution remove magix values * fix partitioner --- .../20_grouped_convolution/conv_configs.hpp | 16 +- ...uped_convolution_backward_data_invoker.hpp | 209 +++++++++--------- ...ed_convolution_backward_weight_invoker.hpp | 90 ++++---- ...tion_backward_weight_two_stage_invoker.hpp | 108 +++++---- .../grouped_convolution_forward_invoker.hpp | 96 ++++---- ...nvolution_forward_large_tensor_invoker.hpp | 119 +++++----- .../utils/grouped_convolution_utils.hpp | 69 ++++-- 7 files changed, 355 insertions(+), 352 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index b9edf247cc..8a2a60a197 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -14,19 +14,11 @@ struct ConvConfigBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; - - static constexpr bool TransposeC = false; - static constexpr ck_tile::index_t VectorSizeA = 4; static constexpr ck_tile::index_t VectorSizeB = 8; static constexpr ck_tile::index_t VectorSizeC = 8; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = 
ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; @@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 14a533ffc9..d19d3ac8ec 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + 
ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdData, + typename GroupedConvTraitsType::BsLayoutBwdData, + typename GroupedConvTraitsType::CLayoutBwdData, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -93,95 +89,96 @@ struct 
GroupedConvolutionBackwardDataInvoker const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ConvConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - 
if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + 
ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0e777c5f8a..81b9d402ce 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA; - constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB; - constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, 
ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker AccDataType, WeiDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index a8e41438c8..8cef2bde65 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { using WorkspaceDataType = float; - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 4; - constexpr ck_tile::index_t 
VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, 
ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker AccDataType, WorkspaceDataType, // C: Workspace normally Out typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::K_Warp_Tile, GemmPipelineProblem::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - 
ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 2290f60d1f..7c8269d13c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr ck_tile::index_t NumGroupsToMerge = 1; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + 
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = 
ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionForwardKernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4d983baac5..9d2752727c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu 
= 1; // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits; + using GroupedConvTraitsTypeDefault = + ck_tile::GroupedConvTraits; using GroupedConvTraitsTypeLargeTensor = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsTypeDefault::AsLayoutFwd, + typename GroupedConvTraitsTypeDefault::BsLayoutFwd, + typename GroupedConvTraitsTypeDefault::CLayoutFwd, + GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC, + GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -88,13 +83,14 @@ 
struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsTypeDefault::VectorSizeA, + GroupedConvTraitsTypeDefault::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker using TransformType = ck_tile::TransformConvFwdToGemm; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; // Use split-image kernel if layout supports it, otherwise use regular kernel @@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 9b5a60ee1f..8ea6cffa7d 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -74,6 +74,21 @@ struct GroupedConvTraits } public: + // Fixed values for Implicit GEMM + struct FixedGemmParams + { + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + static constexpr bool TransposeC = false; + static constexpr bool FixedVectorSize = true; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Persistent = false; + using ELayout = ck_tile::tensor_layout::gemm::RowMajor; + }; + // Compile time parameters static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = NDimSpatial_; @@ -82,31 +97,43 @@ struct GroupedConvTraits using WeiLayout = WeiLayout_; using DsLayout = DsLayout_; using OutLayout = OutLayout_; + + // Forward Gemm Layouts + using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Data Gemm Layouts + using 
AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Weight Gemm Layouts + using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor; + using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + + template using GroupedConvImplicitGemmTraitsFwd = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdData = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdWeight = - TileGemmTraits; + TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits; static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_; static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_; static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_; - static constexpr index_t NumDTensor = DsLayout::size(); + static constexpr ck_tile::index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; From 18e083003fa25a661015542c39b1979200f361cf Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:46:26 +0100 Subject: [PATCH 5/9] [CK_BUILDER] Convolution description (#3163) * Add DirectLoad tparam & clean up headers. * Add convolution traits. * Update inline documentation. * Add more convolution specialization and gemm padding types. * Add additional helper functions & more tests to conv traits. * Fix tests cmake file. * Add case insensitive string comparison * Fix function name overlapping with variable name. * Unify pipeline version and scheduler enums. * Fix includes. * Update test conv traits with unified enums. * Update concepts etc with update unified enum * Fix ckb conv fwd test - unified enum usage. * Dump changes. 
* Add ostream overloads for all enum classes. * Update detailed() function in ConvDescription * Fix handling union based conv direction. * Add test & update conv description. * Refine tree view. * Update copyrights * Fix merge artifacts * Update detailed tree conv description * Fix clang-format --- .../include/ck_tile/builder/builder_utils.hpp | 62 ---- .../builder/conv_signature_predicates.hpp | 16 + .../builder/reflect/conv_description.hpp | 268 +++++++++++++++++ .../ck_tile/builder/reflect/conv_traits.hpp | 2 +- .../builder/reflect/instance_traits_util.hpp | 5 +- .../builder/reflect/tree_formatter.hpp | 106 +++++++ .../builder/include/ck_tile/builder/types.hpp | 275 ++++++++++++++++++ experimental/builder/test/CMakeLists.txt | 3 + .../builder/test/test_conv_description.cpp | 169 +++++++++++ 9 files changed, 842 insertions(+), 64 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp create mode 100644 experimental/builder/test/test_conv_description.cpp diff --git a/experimental/builder/include/ck_tile/builder/builder_utils.hpp b/experimental/builder/include/ck_tile/builder/builder_utils.hpp index 5b4981c630..f16d96bec6 100644 --- a/experimental/builder/include/ck_tile/builder/builder_utils.hpp +++ b/experimental/builder/include/ck_tile/builder/builder_utils.hpp @@ -78,66 +78,4 @@ struct UnsupportedEnumValue { }; -// Helper functions to convert enums to strings -constexpr std::string_view ConvDirectionToString(ConvDirection dir) -{ - switch(dir) - { - case ConvDirection::FORWARD: return "Forward"; - case ConvDirection::BACKWARD_DATA: return "Backward Data"; - case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; - default: return "Unknown"; - } -} - -constexpr std::string_view DataTypeToString(DataType dt) -{ - switch(dt) - { - case DataType::FP16: return "FP16"; - case DataType::FP32: return "FP32"; - case 
DataType::BF16: return "BF16"; - case DataType::FP8: return "FP8"; - case DataType::I8: return "I8"; - case DataType::U8: return "U8"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout1D layout) -{ - switch(layout) - { - case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK"; - case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK"; - case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW"; - case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout2D layout) -{ - switch(layout) - { - case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK"; - case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK"; - case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW"; - case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout3D layout) -{ - switch(layout) - { - case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK"; - case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK"; - case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW"; - case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW"; - default: return "Unknown"; - } -} - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp index f016a342d3..3869c7b538 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR // Predicate for 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); // Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); // Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); @@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward = // Predicate for DeviceGroupedConvBwdWeight operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); // Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); // Predicate for DeviceGroupedConvBwdWeightMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); // Predicate for DeviceGroupedConvBwdWeight_Dl operation. 
template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); @@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight = // Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); // Predicate for DeviceGroupedConvBwdDataMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); // Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp new file mode 100644 index 0000000000..0b58f5a3b7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +/// @file conv_description.hpp +/// @brief Provides human-readable descriptions of ConvBuilder configurations + +namespace ck_tile::reflect::conv { + +struct ConvSignatureInfo +{ + int spatial_dim; + builder::ConvDirection direction; + std::variant + layout; + builder::DataType data_type; + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; +}; + +// Algorithm information - groups all algorithm-related configuration +struct GemmAlgorithmInfo +{ + int thread_block_size; + DataTileInfo tile_dims; + WarpGemmParams warp_gemm; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; + OutputTileTransferInfo c_tile_transfer; + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; + std::variant + conv_specialization; + builder::GemmPadding padding; +}; + +// Provides human-readable descriptions of ConvBuilder configurations. 
+struct ConvDescription +{ + ConvSignatureInfo signature; + GemmAlgorithmInfo algorithm; + + // Brief one-line summary + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " << signature.direction << " convolution"; + return oss.str(); + } + + // Detailed hierarchical description + std::string detailed() const + { + TreeFormatter f; + f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel"); + f.writeLine(1, "Signature"); + f.writeLine(2, "Tensor Type: ", signature.data_type); + f.writeLine(2, "Memory Layout: ", signature.layout); + f.writeLine(2, "Input elementwise operation: ", signature.input_element_op); + f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); + f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); + + f.writeLine(1, "Algorithm"); + // Compute Block section + f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); + f.writeLine(2, + "Data tile size: ", + algorithm.tile_dims.m, + "×", + algorithm.tile_dims.n, + "×", + algorithm.tile_dims.k); + f.writeLine(2, "Gemm padding: ", algorithm.padding); + f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization); + // Pipeline section + f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version); + f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler); + f.writeLine(2, "Warp Gemm parameters: "); + f.writeLine( + 3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n); + f.writeLast(3, + "Number of warp gemm iterations: ", + algorithm.warp_gemm.m_iter, + "×", + algorithm.warp_gemm.n_iter); + + // Memory Access section + f.writeLine(2, "Memory access:"); + + f.writeLine(3, "A Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.a_tile_transfer.tile_dimensions.k0, + "×", + algorithm.a_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.a_tile_transfer.tile_dimensions.k1, + "×"); + 
f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.a_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.a_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLine(3, "B Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.b_tile_transfer.tile_dimensions.k0, + "×", + algorithm.b_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.b_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + 
algorithm.b_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.b_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLast(3, "C Tile transfer: "); + f.writeLine(4, + "Data shuffle (number of gemm instructions per iteration): ", + algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + "×", + algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + f.writeLine(4, + "Spatial thread distribution used to store data: ", + algorithm.c_tile_transfer.thread_cluster_dims[0], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[1], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[2], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[3]); + f.writeLast(4, + "Vector access (GMEM write) instruction size: ", + algorithm.c_tile_transfer.scalar_per_vector); + f.writeLast(2); + f.writeLast(1); + return f.getString(); + } + + // Educational explanation of optimization choices + std::string explain() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } + + // Performance characteristics and use case guidance + std::string suggest() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } +}; + +// Helper concept to detect if a type has InstanceTraits specialization +template +concept HasInstanceTraits = 
requires { typename InstanceTraits; }; + +// Helper concept to detect ConvBuilder types +template +concept IsConvBuilder = requires { + typename T::Factory; + typename T::Instance; +}; + +// Primary factory function: Create ConvDescription from Instance type directly +template + requires HasInstanceTraits +ConvDescription Describe() +{ + using Traits = ConvTraits; + + return ConvDescription{ + .signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .layout = Traits::layout, + .data_type = Traits::data_type, + .input_element_op = Traits::input_element_op, + .weight_element_op = Traits::weight_element_op, + .output_element_op = Traits::output_element_op}, + .algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size, + .tile_dims = Traits::tile_dims, + .warp_gemm = Traits::warp_gemm, + .a_tile_transfer = Traits::a_tile_transfer, + .b_tile_transfer = Traits::b_tile_transfer, + .c_tile_transfer = Traits::c_tile_transfer, + .pipeline_version = Traits::pipeline_version, + .pipeline_scheduler = Traits::pipeline_scheduler, + .conv_specialization = Traits::conv_specialization, + .padding = Traits::gemm_padding}}; +} + +// Backward compatibility: Create ConvDescription from Builder type +template + requires IsConvBuilder && (!HasInstanceTraits) +ConvDescription Describe() +{ + // Delegate to Instance-based version + using Instance = typename Builder::Instance; + return Describe(); +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a74d77d155..86cf11f647 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index c863d2306c..e4d154ae10 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -13,7 +13,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include diff --git a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp new file mode 100644 index 0000000000..6a80a994ee --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +namespace ck_tile::reflect { + +// Helper class for formatting hierarchical tree structures with proper indentation +// and tree-drawing characters (├─, └─, │, etc.) +// +// Example Usage: +// +// TreeFormatter f; +// f.writeLine(0, "Root"); +// f.writeLine(1, "Branch 1"); +// f.writeLine(2, "Item 1a"); +// f.writeLast(2, "Item 1b"); +// f.writeLast(1, "Branch 2"); +// f.writeLast(2, "Item 2a"); +// std::cout << f.getString() << "\n"; +// +// Generated Output: +// +// Root +// ├─ Branch 1 +// │ ├─ Item 1a +// │ └─ Item 1b +// └─ Branch 2 +// └─ Item 2a +class TreeFormatter +{ + public: + TreeFormatter() = default; + + // Write a line at the specified indentation level (branch continues after this) + template + void writeLine(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, false, std::forward(args)...); + } + + // Write the last line at the specified indentation level (branch ends) + template + void writeLast(int indent_level, Args&&... 
args) + { + writeLineImpl(indent_level, true, std::forward(args)...); + } + + // Get the formatted string (removes trailing newline if present) + std::string getString() const + { + std::string result = oss_.str(); + if(!result.empty() && result.back() == '\n') + { + result.pop_back(); + } + return result; + } + + private: + std::ostringstream oss_; + std::vector is_last_at_level_; // Tracks which levels have ended + + // Implementation of line writing with tree symbols + template + void writeLineImpl(int indent_level, bool is_last, Args&&... args) + { + // Ensure we have enough tracking space + if(static_cast(indent_level) >= is_last_at_level_.size()) + { + is_last_at_level_.resize(indent_level + 1, false); + // Level 0 (root) should always be treated as "last" since it has no tree symbols + if(is_last_at_level_.size() > 0) + { + is_last_at_level_[0] = true; + } + } + + // Draw the tree structure + // Start from level 1 (skip level 0 which is the root with no symbols) + for(int i = 1; i < indent_level; ++i) + { + // For all parent levels, draw vertical line or space based on whether they ended + oss_ << (is_last_at_level_[i] ? " " : "│ "); + } + + // Draw the branch symbol for the current level + if(indent_level > 0) + { + oss_ << (is_last ? 
"└─ " : "├─ "); + } + + // Write the content using fold expression with direct stream insertion + ((oss_ << std::forward(args)), ...); + + oss_ << '\n'; + + // Update tracking for this level AFTER writing the line + // This ensures future lines at deeper levels know if this level ended + is_last_at_level_[indent_level] = is_last; + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2af10346e5..a58c994288 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + namespace ck_tile::builder { enum class DataType @@ -215,4 +219,275 @@ enum class PipelineScheduler INTERWAVE }; +// ostream operator overloads for enum classes +inline std::ostream& operator<<(std::ostream& os, DataType dt) +{ + using enum DataType; + switch(dt) + { + case FP16: return os << "FP16"; + case FP32: return os << "FP32"; + case BF16: return os << "BF16"; + case FP8: return os << "FP8"; + case I8: return os << "I8"; + case U8: return os << "U8"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + using enum ConvDirection; + switch(dir) + { + case FORWARD: return os << "Forward"; + case BACKWARD_DATA: return os << "Backward Data"; + case BACKWARD_WEIGHT: return os << "Backward Weight"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout) +{ + using enum GroupConvLayout1D; + switch(layout) + { + case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK"; + case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK"; + case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW"; + case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D 
layout) +{ + using enum GroupConvLayout2D; + switch(layout) + { + case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK"; + case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK"; + case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW"; + case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) +{ + using enum GroupConvLayout3D; + switch(layout) + { + case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK"; + case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK"; + case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW"; + case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) +{ + using enum FwdGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: + return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; + case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: + return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) +{ + using enum BwdDataGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; + case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; + case 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: + return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) +{ + using enum BwdWeightGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; + case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + case DeviceGroupedConvBwdWeight_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; + case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) +{ + using enum ElementwiseOperation; + switch(op) + { + case BIAS: return os << "BIAS"; + case BIAS_CLAMP: return os << "BIAS_CLAMP"; + case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; + case BILINEAR: return os << "BILINEAR"; + case CLAMP: return os << "CLAMP"; + case SCALE: return os << "SCALE"; + case PASS_THROUGH: return os << "PASS_THROUGH"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) +{ + using enum PipelineVersion; + switch(ver) + { + case V1: return os << "V1"; + case V2: return os << "V2"; + case V3: return os << "V3"; + case V4: return os << "V4"; + case V5: return os << "V5"; + case WEIGHT_ONLY: return os << "WEIGHT_ONLY"; + 
default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) +{ + using enum GemmSpecialization; + switch(spec) + { + case Default: return os << "Default"; + case MPadding: return os << "MPadding"; + case NPadding: return os << "NPadding"; + case KPadding: return os << "KPadding"; + case MNPadding: return os << "MNPadding"; + case MKPadding: return os << "MKPadding"; + case NKPadding: return os << "NKPadding"; + case MNKPadding: return os << "MNKPadding"; + case OPadding: return os << "OPadding"; + case MOPadding: return os << "MOPadding"; + case NOPadding: return os << "NOPadding"; + case KOPadding: return os << "KOPadding"; + case MNOPadding: return os << "MNOPadding"; + case MKOPadding: return os << "MKOPadding"; + case NKOPadding: return os << "NKOPadding"; + case MNKOPadding: return os << "MNKOPadding"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +{ + using enum ConvFwdSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_3x3: return os << "FILTER_3x3"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) +{ + using enum ConvBwdDataSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) +{ + using enum ConvBwdWeightSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case ODD_C: return os << "ODD_C"; + default: return os << 
"Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) +{ + using enum GemmPadding; + switch(padding) + { + case DEFAULT: return os << "DEFAULT"; + case M_PADDING: return os << "M_PADDING"; + case N_PADDING: return os << "N_PADDING"; + case K_PADDING: return os << "K_PADDING"; + case MN_PADDING: return os << "MN_PADDING"; + case MK_PADDING: return os << "MK_PADDING"; + case NK_PADDING: return os << "NK_PADDING"; + case MNK_PADDING: return os << "MNK_PADDING"; + case O_PADDING: return os << "O_PADDING"; + case MO_PADDING: return os << "MO_PADDING"; + case NO_PADDING: return os << "NO_PADDING"; + case KO_PADDING: return os << "KO_PADDING"; + case MNO_PADDING: return os << "MNO_PADDING"; + case MKO_PADDING: return os << "MKO_PADDING"; + case NKO_PADDING: return os << "NKO_PADDING"; + case MNKO_PADDING: return os << "MNKO_PADDING"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) +{ + using enum PipelineScheduler; + switch(sched) + { + case DEFAULT: return os << "DEFAULT"; + case INTRAWAVE: return os << "INTRAWAVE"; + case INTERWAVE: return os << "INTERWAVE"; + default: return os << "Unknown"; + } +} + +// ostream operator overload for std::variant of layout types +inline std::ostream& +operator<<(std::ostream& os, + const std::variant& layout) +{ + std::visit([&os](const auto& l) { os << l; }, layout); + return os; +} + +// ostream operator overload for std::variant of convolution specializations +inline std::ostream& operator<<(std::ostream& os, + const std::variant& spec) +{ + std::visit([&os](const auto& s) { os << s; }, spec); + return os; +} + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 0cb3237f8c..b776edbcde 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -67,6 +67,9 @@ 
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test add_ck_builder_test(test_conv_traits conv/test_conv_traits.cpp) +add_ck_builder_test(test_conv_description + test_conv_description.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp new file mode 100644 index 0000000000..97af4af795 --- /dev/null +++ b/experimental/builder/test/test_conv_description.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include "testing_utils.hpp" +#include "impl/conv_signature_types.hpp" +#include "impl/conv_algorithm_types.hpp" + +namespace { + +namespace ckb = ck_tile::builder; +namespace ckr = ck_tile::reflect::conv; +namespace ckt = ck_tile::test; + +// Defines the signature of the convolution operation to be tested. +// This includes dimensionality, direction, data layout, and data type. 
+struct ConvSignature // Test fixture: default member values select the convolution this test file describes.
+{
+    int spatial_dim                 = 2; // two spatial dimensions (the tests expect a "2D ..." description)
+    ckb::ConvDirection direction    = ckb::ConvDirection::FORWARD; // forward convolution
+    ckb::GroupConvLayout layout     = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; // 2D layout enumerator; printed by name in the detailed description
+    ckb::DataType data_type         = ckb::DataType::FP16; // expected to be reported as "Tensor Type: FP16"
+    ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; // single setting; the detailed description lists PASS_THROUGH for input, weights and output
+    ckb::GroupConvDeviceOp device_operation = // device operation the builder should instantiate
+        ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
+};
+static_assert(ckb::ConvSignatureDescriptor); // compile-time concept check; NOTE(review): the concept's template argument is not visible here (presumably <ConvSignature>, lost in extraction) - confirm
+
+struct DefaultAlgorithm // Test fixture: algorithm (tuning) parameters paired with ConvSignature in the description tests.
+{
+    ckb::test::ThreadBlock thread_block{.block_size = 256, // expected output: "Thread block size: 256"
+                                        .tile_size  = {.m = 256, .n = 256, .k = 32}}; // expected output: "Data tile size: 256×256×32"
+
+    ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1            = 8,
+                                             .bk1            = 8,
+                                             .m_per_xdl      = 16, // with n_per_xdl: expected "subtile size: 16×16"
+                                             .n_per_xdl      = 16,
+                                             .m_xdl_per_wave = 4, // with n_xdl_per_wave: expected "Number of warp gemm iterations: 4×4"
+                                             .n_xdl_per_wave = 4};
+
+    ckb::test::BlockTransferABC block_transfer{ // A and B are configured identically; C has its own cluster/epilogue settings
+        .block_transfer_a      = {.k0 = 4, .m_n = 256, .k1 = 8},
+        .block_transfer_b      = {.k0 = 4, .m_n = 256, .k1 = 8},
+        .thread_cluster_dims_c = {.m_block        = 1, // the four values are expected to print as "1×32×1×8"
+                                  .m_wave_per_xdl = 32,
+                                  .n_block        = 1,
+                                  .n_wave_per_xdl = 8},
+        .lds_transfer_a        = {.src_vector_dim            = 2,
+                                  .src_scalar_per_vector     = 8,
+                                  .lds_dst_scalar_per_vector = 8,
+                                  .is_direct_load            = true,
+                                  .lds_padding               = false},
+        .lds_transfer_b        = {.src_vector_dim            = 2,
+                                  .src_scalar_per_vector     = 8,
+                                  .lds_dst_scalar_per_vector = 8,
+                                  .is_direct_load            = true,
+                                  .lds_padding               = false},
+        .epilogue_c            = {.m_per_wave_per_shuffle = 1, // with n_per_wave_per_shuffle: expected "1×1" data shuffle
+                                  .n_per_wave_per_shuffle = 1,
+                                  .scalar_per_vector      = 8},
+        .block_transfer_access_order_a = {.order = {0, 1, 2}}, // all four access orders use the identity order 0,1,2
+        .block_transfer_access_order_b = {.order = {0, 1, 2}},
+        .src_access_order_a            = {.order = {0, 1, 2}},
+        .src_access_order_b            = {.order = {0, 1, 2}}};
+
+    ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; // expected output: "Convolution specialization: DEFAULT"
+    ckb::GemmSpecialization gemm_specialization   = ckb::GemmSpecialization::Default; // expected output: "Gemm padding: DEFAULT"
+    ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, // expected output: "Pipeline version: V4"
+                                    .scheduler = ckb::PipelineScheduler::INTRAWAVE}; // expected output: "Pipeline scheduler: INTRAWAVE"
+};
+static_assert(ckb::ConvAlgorithmDescriptor); + +TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); +} + +TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().detailed(), + ckt::StringEqWithDiff( // + "2D Forward Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "├─ Algorithm\n" + "│ ├─ Thread block size: 256\n" + "│ ├─ Data tile size: 256×256×32\n" + "│ ├─ Gemm padding: DEFAULT\n" + "│ ├─ Convolution specialization: DEFAULT\n" + "│ ├─ Pipeline version: V4\n" + "│ ├─ Pipeline scheduler: INTRAWAVE\n" + "│ ├─ Warp Gemm parameters: \n" + "│ │ ├─ subtile size: 16×16\n" + "│ │ └─ Number of warp gemm iterations: 4×4\n" + "│ ├─ Memory access:\n" + "│ │ ├─ A Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ ├─ B Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the 
data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ └─ C Tile transfer: \n" + "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + "│ │ └─ Vector access (GMEM write) instruction size: 8\n" + "│ └─ \n" + "└─ ")); +} + +// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory +// does not have a specialization for backward data convolutions. The test fails with: +// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'" +// +// To enable this test, a ConvFactory specialization for backward data operations must be +// implemented first. +// +// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription) +// { +// struct BackwardDataSignature +// { +// int spatial_dim = 2; +// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA; +// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; +// ckb::DataType data_type = ckb::DataType::FP16; +// ckb::ElementwiseOperation elementwise_operation = +// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation = +// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; +// }; +// static_assert(ckb::ConvSignatureDescriptor); +// +// static constexpr const BackwardDataSignature SIGNATURE; +// static constexpr const DefaultAlgorithm ALGORITHM; +// using Builder = ckb::ConvBuilder; +// +// // Verify Brief works +// EXPECT_THAT(ckr::Describe().brief(), +// ckt::StringEqWithDiff("2D Backward Data convolution")); +// +// // Verify detailed works - to be updated once ConvFactory is 
implemented +// EXPECT_THAT(ckr::Describe().detailed(), +// ckt::StringEqWithDiff("PLACEHOLDER")); +// } +} // namespace From 76c4c12f5959adcd56d1627a1d1ce885deb9d096 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Fri, 7 Nov 2025 00:07:39 +0100 Subject: [PATCH 6/9] Add .clangd and CMakeUserPresets.json to .gitignore (#3171) --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 6641e5bc58..2641a661d8 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,12 @@ docs/doxygen/xml cmake-build*/ build*/ +# LSP configuration +.clangd + +# User-defined CMake presets +CMakeUserPresets.json + # Python virtualenv .venv/ From 5f3cae3e28a042e411afcd2e54b16cc6909c5bbb Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Fri, 7 Nov 2025 02:29:48 +0200 Subject: [PATCH 7/9] [CK_BUILDER]ckb add remining fwd conv device ops (#3155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. 
* add device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk * Add full DL configurability with Option A implementation - Added 5 DL descriptor structs (39 configurable parameters) - Added 10 C++20 concepts for type-safe validation - Updated factory to read all parameters from descriptors - Updated test helper to populate all descriptors - All tests passing (13/13 including 3 new DL tests) * Add factory and test support for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add factory specialization for Large_Tensor device operation (conv_factory.hpp lines 1145-1265) - Add macro collision workaround using pragma push/pop (conv_factory.hpp lines 43-51) - Add test helper function run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add builder test file test_ckb_conv_fwd_2d_large_tensor_fp16.cpp with 2 test cases - Update CMakeLists.txt to include new test file - Reuse existing ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor - Map all 42 template parameters identical to regular XDL CShuffle - All 15 builder tests passing including 2 new Large_Tensor tests Completes Task 350: All 4 forward convolution device operations now supported in CK Builder. * Update copyright headers to new format - Change copyright format to: Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
- Reorder headers: Copyright first, then SPDX-License-Identifier - Updated files: * experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp * experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp * experimental/builder/include/ck_tile/builder/device_op_types.hpp * fix c++ 18 format * Fix clang-format-18 error in device_op_types.hpp --------- Co-authored-by: Ville Pietilä Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com> --- .../builder/conv_algorithm_concepts.hpp | 83 ++++++ .../include/ck_tile/builder/conv_factory.hpp | 271 ++++++++++++++++++ .../ck_tile/builder/device_op_types.hpp | 22 ++ experimental/builder/test/CMakeLists.txt | 2 + .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 69 +++++ ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 53 ++++ .../test/impl/conv_algorithm_types.hpp | 80 ++++++ .../test/utils/ckb_conv_test_common.hpp | 145 ++++++++++ 8 files changed, 725 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/device_op_types.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index e43f910a73..6006efe4f8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +/******************************************** */ +/* DL-specific descriptors and requirements */ +/******************************************** */ + +// Concept for DL thread configuration +template +concept DlThreadConfigDescriptor = requires(T t) { + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> 
std::convertible_to; + { t.m1_per_thread } -> std::convertible_to; + { t.n1_per_thread } -> std::convertible_to; + { t.k_per_thread } -> std::convertible_to; +}; + +// Concept for DL thread cluster +template +concept DlThreadClusterDescriptor = requires(T t) { + { t.m1_xs } -> std::convertible_to>; + { t.n1_xs } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_M0_M1_K1 format +template +concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_N0_N1_K1 format +template +concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL C thread transfer +template +concept DlCThreadTransferDescriptor = requires(T t) { + { t.src_dst_access_order } -> std::convertible_to>; + { t.src_dst_vector_dim } -> std::convertible_to; + { t.dst_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if algorithm specifies DL thread config +template +concept SpecifiesDlThreadConfig = requires { + { T::dl_thread_config } -> DlThreadConfigDescriptor; +}; + +// Concept to check if algorithm specifies DL thread cluster +template +concept SpecifiesDlThreadCluster = requires { + { 
T::dl_thread_cluster } -> DlThreadClusterDescriptor; +}; + +// Concept to check if algorithm specifies DL A block transfer +template +concept SpecifiesDlBlockTransferA = requires { + { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL B block transfer +template +concept SpecifiesDlBlockTransferB = requires { + { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL C thread transfer +template +concept SpecifiesDlCThreadTransfer = requires { + { T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 1ccc190ba2..e40199987d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -36,9 +36,21 @@ #pragma once +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +// WORKAROUND: Macro namespace collision in upstream CK device operation headers. +// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and +// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define +// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors. +// Use pragma push/pop to isolate the Large_Tensor header's macro scope. 
+#pragma push_macro("GridwiseGemmTemplateParameters") +#ifdef GridwiseGemmTemplateParameters +#undef GridwiseGemmTemplateParameters +#endif +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#pragma pop_macro("GridwiseGemmTemplateParameters") #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" @@ -990,4 +1002,263 @@ struct ConvFactory GRIDWISE_GEMM_PIPELINE_VERSION>; }; +// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel using Direct Load (DL) approach. +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesDlThreadConfig, + "DL algorithm must specify thread config."); + static_assert(SpecifiesDlThreadCluster, + "DL algorithm must specify thread cluster."); + static_assert(SpecifiesDlBlockTransferA, + "DL algorithm must specify A block transfer."); + static_assert(SpecifiesDlBlockTransferB, + "DL algorithm must specify B block transfer."); + static_assert(SpecifiesDlCThreadTransfer, + "DL algorithm must specify C thread transfer."); + + static constexpr auto 
FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using 
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + SPATIAL_DIM, + typename Types::ADataType, + typename Types::BDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Types::AccDataType, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + FWD_CONV_SPECIALIZATION, + GEMM_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + 
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance +// of a grouped forward convolution kernel with large tensor support (N-splitting). +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseXdlGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); 
+ + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance with large tensor support. 
+ using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER>; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp new file mode 100644 index 0000000000..0e779fdf4e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -0,0 +1,22 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::builder { + +// Enumeration for CK Device Operation types. +// This allows the builder to select which device operation template to instantiate +// based on the user's requirements. +enum class DeviceOpType +{ + // Forward Convolution - Non-grouped + CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) + + // Forward Convolution - Grouped + GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: + // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +}; + +} // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b776edbcde..43c4fd4857 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp + conv/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp conv/test_ckb_conv_fwd_3d_fp32.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp new file mode 100644 index 0000000000..12730bab19 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -0,0 +1,69 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 
16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp new file mode 100644 index 0000000000..0216c5907d --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -0,0 +1,53 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::DEFAULT>(); +} + +TEST( + FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr 
ThreadBlock FwdThreadBlock{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index accc4048dc..88c5b5787a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -214,4 +214,84 @@ static_assert( static_assert( ckb::SpecifiesLoopScheduler); +// DL-specific descriptors +struct DlThreadConfig +{ + size_t k0_per_block; + size_t k1; + size_t m1_per_thread; + size_t n1_per_thread; + size_t k_per_thread; +}; +static_assert(ckb::DlThreadConfigDescriptor); + +struct DlThreadCluster +{ + std::array m1_xs; // e.g., {8, 2} + std::array n1_xs; // e.g., {8, 2} +}; +static_assert(ckb::DlThreadClusterDescriptor); + +struct DlBlockTransferK0M0M1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); + +struct DlBlockTransferK0N0N1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); + +struct DlCThreadTransfer +{ + std::array src_dst_access_order; + size_t src_dst_vector_dim; + size_t dst_scalar_per_vector; +}; +static_assert(ckb::DlCThreadTransferDescriptor); + +struct 
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +{ + ThreadBlock thread_block; + ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + DlThreadConfig dl_thread_config; + DlThreadCluster dl_thread_cluster; + DlBlockTransferK0M0M1K1 dl_block_transfer_a; + DlBlockTransferK0N0N1K1 dl_block_transfer_b; + DlCThreadTransfer dl_c_thread_transfer; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesDlThreadConfig); +static_assert( + ckb::SpecifiesDlThreadCluster); +static_assert( + ckb::SpecifiesDlBlockTransferA); +static_assert( + ckb::SpecifiesDlBlockTransferB); +static_assert( + ckb::SpecifiesDlCThreadTransfer); + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7fd02a56f7..14fae566f6 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() EXPECT_NE(invoker_ptr, nullptr); } +template +constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() +{ + // DL thread configuration + constexpr DlThreadConfig DlThreadCfg{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + + // DL thread cluster + constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + + // DL A block transfer - K0_M0_M1_K1 format + constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + 
.src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL B block transfer - K0_N0_N1_K1 format + constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL C thread transfer + constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .dl_thread_config = DlThreadCfg, + .dl_thread_cluster = DlCluster, + .dl_block_transfer_a = DlBlockTransferA, + .dl_block_transfer_b = DlBlockTransferB, + .dl_c_thread_transfer = DlCTransfer}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else 
if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + +// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// Note: Large_Tensor has identical parameters to regular XDL CShuffle +template +constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() +{ + constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + // Large_Tensor uses the same descriptor as regular XDL CShuffle + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .num_groups_to_merge = 1, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = ConvBuilder; + + auto instance = 
typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE( + kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils From d04eba4ae37c8c2d40855f02aa861e1ac1ec7b3f Mon Sep 17 00:00:00 2001 From: Xudong Yuan Date: Fri, 7 Nov 2025 08:45:41 +0800 Subject: [PATCH 8/9] Ck moe mxfp4 blockm32 (#3098) * block_m = 32 * ck block_m = 32 * aiter/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp format * mxfp4_moe v1 pipe * update format --------- Co-authored-by: zhimding Co-authored-by: lalala-sh Co-authored-by: felix --- .../moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp | 12 +- ...xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp | 3 +- ...ne_xdlops_b_preshuffle_mx_moe_selector.hpp | 24 +- ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 891 ++++++++++++++++++ ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 234 +++-- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 2 +- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 468 +++++---- 7 files changed, 1357 insertions(+), 277 deletions(-) create mode 
100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 1adf039b70..ebb73ca7e0 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 64, KPerBlock, + MPerBlock, 128, KPerBlock, 16, 16, 16, 16, - 4, 2, + 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, @@ -213,10 +213,10 @@ int main(int argc, char* argv[]) ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t N = 6144; - ck::index_t K = 4096; + ck::index_t N = 7168; + ck::index_t K = 256; ck::index_t experts = 8; - ck::index_t tokens = 832; + ck::index_t tokens = 208; ck::index_t topk = 2; if(argc == 1) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index 
b3b3d312c7..b621c3a93d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< }); }); - HotLoopScheduler(); + if constexpr(MPerBlock >= 64) + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp index 6789d26a45..5223993671 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp" namespace ck { @@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector() } else { - return nullptr; + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp new file mode 100644 index 0000000000..fc5cb60c37 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -0,0 +1,891 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using 
Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ 
static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves + + num_buffer_load_a_scale + num_buffer_load_b_scale; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 128) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 
2 : 1, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, 
KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + 
a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * 
NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, 
in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 
1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + 
typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; + }); + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type 
a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < 
ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; 
+}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 2b936c8d25..7473d2f2e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 2) + { - // Group num_mfma_perstage num_ds_read_a_perstage - // since we want to reuse a local register buffer - constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; - constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - constexpr auto num_ds_read_a_mfma_perstage = - math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - constexpr auto num_ds_read_a_prefetch_stages = 2; + constexpr auto num_ds_read_a_prefetch_stages = 2; - constexpr auto buffer_load_perstage_more = - math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_less = - math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_stage2 = - math::integer_divide_floor((num_buffer_load_stage2), 2); + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), 
(num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); - constexpr auto buffer_load_stages_more = - num_buffer_load_stage1 - - math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * - ((num_total_stages - 2)); + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); - constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more; - constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less; - constexpr auto buffer_load_issue_point_interval_stage2 = - num_mfma_perstage / buffer_load_perstage_stage2; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; - // Stage 1 - // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, 
(num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + else + { + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + + num_buffer_load_b_scale; + constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( + num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < + mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } - }); - }); - - // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + else { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read } }); - }); - - // Stage 2, Sync - // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - }); + } } template ()) - { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto 
splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); - } + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } @@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, // NPerBlock>; -#if 0 template @@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + ignore = a_element_op; ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? 
problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( - make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); - // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); 
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); - const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { @@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); constexpr auto AKThreads = AK0Threads * AK1Threads; constexpr auto AMRepeats = MPerBlock / AMThreads; - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads; if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads]; index_t token_offset = fused_token & 0xffffff; if constexpr(!IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); + const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) * - math::integer_divide_ceil(problem.K, ScaleBlockSize)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + // Gride buffer creation const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - LDSTypeA, + ADataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - 
ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, IndexType, - 1, - BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf = make_static_buffer( + auto b_block_buf_ping = make_static_buffer( b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2{}, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, + Sequence<0, 1, 2, 3, 4>, 4, BBlockTransferSrcScalarPerVector, BThreadTransferSrcResetCoordinateAfterRun, @@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto waveId_m 
= wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; auto a_thread_offset_m = waveId_m; + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< AScaleDataType, AScaleDataType, @@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths Sequence<0, 1, 2>, // DimAccessOrder 2, // SrcVectorDim - KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector 1, // SrcScalarStrideInVector true>(b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, @@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * expert_stride / BPackedSize, + p_b_grid_up + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); - const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + 
expert_scale_stride / 2; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, + auto b_blockwise_copy_up = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< BScaleDataType, BScaleDataType, @@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle thread_offset_shuffled / scale_pack_size_b)); blockwise_gemm_pipeline.template Run( + // A a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, + // Gate and Up b_grid_desc_bpreshuffled, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_blockwise_copy_up, b_grid_buf, b_grid_buf_up, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, + // C c_thread_buf, c_thread_buf_up, + // A scale a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf, + // B scale b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_thread_copy_up, @@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle else { blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, + a_grid_desc_ak0_m_ak1, // A a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_bpreshuffled, + b_grid_desc_bpreshuffled, // B b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_bufs, 
b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale a_scale_thread_copy, a_scale_grid_buf, - b_scale_grid_desc_bn_ak, + b_scale_grid_desc_bn_ak, // B scale b_scale_thread_copy, b_scale_grid_buf, num_k_block_main_loop); @@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = 
gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); } - } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); }); }); }); @@ -1738,19 +1771,25 @@ struct 
GridwiseMoeGemmMX_BPreshuffle make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl M3, - M4)), + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto n_thread_data_on_block_idx = @@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto 
c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 1; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), CElementwiseOperation, - 
Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, @@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence Date: Fri, 7 Nov 2025 11:42:39 +0800 Subject: [PATCH 9/9] fix MX bpreshuffle gemm B grid descriptor dimension error. (#3170) --- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 3d2ef9b6c4..7c5bd606b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); constexpr index_t NkSwizzleNumber = Number{}; - return make_naive_tensor_descriptor_packed( - make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + return make_naive_tensor_descriptor_packed(make_tuple( + math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber)); } __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(