From b8527a92360496666ed6606e53ddc97e35dcf76e Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 5 Nov 2025 17:53:06 +0100 Subject: [PATCH 001/118] [CK_BUILDER] Convolution traits. (#3152) Added: 1. Convolution traits & unit tests 2. Update builder enumerators to have representation of Convolution Kernels properties. 3. Unified builder pipeline version & scheduler enumerators --- .../builder/conv_algorithm_concepts.hpp | 12 +- .../include/ck_tile/builder/conv_factory.hpp | 48 +- .../ck_tile/builder/reflect/conv_traits.hpp | 719 ++++++++++++++++++ .../builder/reflect/instance_traits.hpp | 11 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 1 + ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 1 + .../builder/reflect/instance_traits_util.hpp | 28 + .../builder/include/ck_tile/builder/types.hpp | 61 +- experimental/builder/test/CMakeLists.txt | 3 + .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 4 +- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../builder/test/conv/test_conv_traits.cpp | 316 ++++++++ .../test/impl/conv_algorithm_types.hpp | 10 +- .../test/utils/ckb_conv_test_common.hpp | 18 +- ...olution_backward_weight_specialization.hpp | 2 + 20 files changed, 1165 insertions(+), 81 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp create mode 100644 experimental/builder/test/conv/test_conv_traits.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 365835684e..e43f910a73 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template concept BlockGemmDescriptor = requires(T t) { - { t.pipeline_version } -> std::convertible_to; - { t.scheduler } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; }; // Concept for parameters that describe a gridwise WMMA GEMM problem. @@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> std::convertible_to; { t.m_wmma_per_wave } -> std::convertible_to; { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; }; // Concept for vectorized data transfer for convolution input tensors. @@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires { template concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; + { T::loop_scheduler } -> std::convertible_to; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index a3932f524c..1ccc190ba2 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm() ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Intrawave; } - else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) + else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Interwave; } else { - static_assert(false, "Unknown BlockGemmPipelineScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } - if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + if constexpr(BG.pipeline_version == PipelineVersion::V1) { version = ck::BlockGemmPipelineVersion::v1; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) + else if constexpr(BG.pipeline_version == PipelineVersion::V2) { version = ck::BlockGemmPipelineVersion::v2; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) + else if constexpr(BG.pipeline_version == PipelineVersion::V3) { version = ck::BlockGemmPipelineVersion::v3; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) + else if constexpr(BG.pipeline_version == PipelineVersion::V4) { version = ck::BlockGemmPipelineVersion::v4; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) + else if constexpr(BG.pipeline_version == PipelineVersion::V5) { version = ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) { return ck::LoopScheduler::Default; } - else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) + else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) { return ck::LoopScheduler::Interwave; } else { - static_assert(false, "Unknown LoopScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } } @@ -460,29 +460,29 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + if constexpr(pipeline_version == PipelineVersion::V1) { return ck::PipelineVersion::v1; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) + else if constexpr(pipeline_version == PipelineVersion::V2) { return ck::PipelineVersion::v2; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) + else if constexpr(pipeline_version == PipelineVersion::V3) { static_assert(false, "V3 is used only for stream-K."); } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) + else if constexpr(pipeline_version == PipelineVersion::V4) { return ck::PipelineVersion::v4; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) + else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) { return ck::PipelineVersion::weight_only; } else { - static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } @@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - if constexpr(version == BlockGemmPipelineVersion::V1) + if constexpr(version == PipelineVersion::V1) { return ck::BlockGemmPipelineVersion::v1; } - else if constexpr(version == BlockGemmPipelineVersion::V2) + else if constexpr(version == PipelineVersion::V2) { return ck::BlockGemmPipelineVersion::v2; } - else if constexpr(version == BlockGemmPipelineVersion::V3) + else if constexpr(version == PipelineVersion::V3) { return ck::BlockGemmPipelineVersion::v3; } - else if constexpr(version == BlockGemmPipelineVersion::V4) + else if constexpr(version == PipelineVersion::V4) { return ck::BlockGemmPipelineVersion::v4; } - else if constexpr(version == BlockGemmPipelineVersion::V5) + else if constexpr(version == PipelineVersion::V5) { return ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp new file mode 100644 index 0000000000..a74d77d155 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::conv { + +// Helper metafunctions to convert from ck enums to builder enums + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v3) + return V3; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == v5) + return V5; +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == weight_only) + return WEIGHT_ONLY; +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates +/// across multiple wavefronts. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Intrawave) + return INTRAWAVE; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. DEFAULT uses the standard +/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved +/// performance in certain scenarios. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Default) + return DEFAULT; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Helper structures for organizing trait data with domain-specific naming + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. +struct InputTileTransferDimensions +{ + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. +}; + +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. +struct InputTileTransferParams +{ + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. + std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. +}; + +/// @brief Complete information for an input tile transfer. +/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. +}; + +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. +struct WarpGemmParams +{ + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). +}; + +/// @brief Parameters for shuffling data between warps (CShuffle optimization). +/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. +}; + +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. +}; + +// Helper metafunctions to derive signature information from Instance types + +/// @brief Derives the convolution direction from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). +template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + { + return builder::ConvDirection::FORWARD; + } + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + { + return builder::ConvDirection::BACKWARD_DATA; + } + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + { + return builder::ConvDirection::BACKWARD_WEIGHT; + } + else + { + return builder::ConvDirection::FORWARD; // Default fallback + } +} + +/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or +/// `builder::ConvBwdWeightSpecialization` enum value. +template +constexpr auto conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + + if constexpr(InstTraits::kConvForwardSpecialization == Default) + { + return builder::ConvFwdSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) + { + return builder::ConvFwdSpecialization::FILTER_3x3; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + + if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + { + return builder::ConvBwdDataSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + + if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + { + return builder::ConvBwdWeightSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) + { + return builder::ConvBwdWeightSpecialization::ODD_C; + } + } +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. +template +constexpr auto conv_layout() +{ + using InstTraits = InstanceTraits; + using ALayout = typename InstTraits::ALayout; + using BLayout = typename InstTraits::BLayout; + using ELayout = typename InstTraits::ELayout; + + namespace ctc = ck::tensor_layout::convolution; + + if constexpr(InstTraits::kSpatialDim == 1) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + } + } + else if constexpr(InstTraits::kSpatialDim == 2) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + } + } + else if constexpr(InstTraits::kSpatialDim == 3) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + } + } +} + +/// @brief Derives the data type from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + + if constexpr(std::is_same_v) + { + return builder::DataType::FP16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::BF16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP32; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::U8; + } + else + { + // Default fallback + return builder::DataType::FP32; + } +} + +/// @brief Derives the elementwise operation from op type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. +template +constexpr builder::ElementwiseOperation elementwise_op() +{ + constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "Bias")) + { + return builder::ElementwiseOperation::BIAS; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) + { + return builder::ElementwiseOperation::BIAS_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + { + return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + { + return builder::ElementwiseOperation::BILINEAR; + } + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + { + return builder::ElementwiseOperation::CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + { + return builder::ElementwiseOperation::SCALE; + } + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + { + return builder::ElementwiseOperation::PASS_THROUGH; + } +} + +/// @brief Derives a gemm padding from a kernel instance type. +/// @tparam Instance - A Device Kernel object type. +/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. +template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto gemm_spec = InstTraits::kGemmSpecialization; + + if constexpr(gemm_spec == Default) + { + return DEFAULT; + } + else if constexpr(gemm_spec == MPadding) + { + return M_PADDING; + } + else if constexpr(gemm_spec == NPadding) + { + return N_PADDING; + } + else if constexpr(gemm_spec == KPadding) + { + return K_PADDING; + } + else if constexpr(gemm_spec == MNPadding) + { + return MN_PADDING; + } + else if constexpr(gemm_spec == MKPadding) + { + return MK_PADDING; + } + else if constexpr(gemm_spec == NKPadding) + { + return NK_PADDING; + } + else if constexpr(gemm_spec == MNKPadding) + { + return MNK_PADDING; + } + else if constexpr(gemm_spec == OPadding) + { + return O_PADDING; + } + else if constexpr(gemm_spec == MOPadding) + { + return MO_PADDING; + } + else if constexpr(gemm_spec == NOPadding) + { + return NO_PADDING; + } + else if constexpr(gemm_spec == KOPadding) + { + return KO_PADDING; + } + else if constexpr(gemm_spec == MNOPadding) + { + return MNO_PADDING; + } + else if constexpr(gemm_spec == MKOPadding) + { + return MKO_PADDING; + } + else if constexpr(gemm_spec == NKOPadding) + { + return NKO_PADDING; + } + else if constexpr(gemm_spec == MNKOPadding) + { + return MNKO_PADDING; + } +} + +/// @brief Primary template for extracting convolution traits. +/// @details This struct is the main entry point for reflecting on a convolution +/// kernel's properties. It is specialized to handle different kinds of input types. +template +struct ConvTraits; + +/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. +/// @details This is the primary specialization used to extract a comprehensive +/// set of traits directly from a fully-formed device kernel `Instance` type. +/// It uses `InstanceTraits` to access the kernel's template parameters. +template + requires requires { typename InstanceTraits; } +struct ConvTraits +{ + using InstTraits = InstanceTraits; + + // --- Signature Information --- + /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). + static constexpr int spatial_dim = InstTraits::kSpatialDim; + /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). + static constexpr builder::ConvDirection direction = conv_direction(); + /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). + static constexpr auto layout = conv_layout(); + /// @brief The primary data type used in the computation (e.g., FP16, FP32). + static constexpr builder::DataType data_type = conv_data_type(); + + static constexpr builder::ElementwiseOperation input_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation weight_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation output_element_op = + elementwise_op(); + + /// @brief The GEMM specialization used by the kernel - padding + static constexpr auto gemm_padding = gemm_spec(); + /// @brief The convolution-specific specialization (e.g., Default, 1x1). + static constexpr auto conv_specialization = conv_spec(); + + // --- Algorithm Information --- + /// @brief The total number of threads in a thread block (workgroup). + static constexpr int thread_block_size = InstTraits::kBlockSize; + /// @brief The dimensions of the data tile processed by the thread block. + static constexpr DataTileInfo tile_dims = { + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + + /// @brief Configuration for the A-matrix (input) tile transfer. + static constexpr InputTileTransferInfo a_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + + /// @brief Configuration for the B-matrix (weights) tile transfer. + static constexpr InputTileTransferInfo b_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + + /// @brief Parameters for the warp-level GEMM computation. + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; + + /// @brief Configuration for the C-matrix (output) tile transfer. + static constexpr OutputTileTransferInfo c_tile_transfer = { + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + + /// @brief Helper to safely get the pipeline version. + /// @details This is only available for some convolutions (e.g., forward). + /// If not present in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_version() + { + if constexpr(requires { T::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + // Return a default or indicate not available + return builder::PipelineVersion::V1; + } + } + + /// @brief The block GEMM pipeline version used by the kernel. + static constexpr auto pipeline_version = get_pipeline_version(); + + /// @brief Helper to safely get the pipeline scheduler. + /// @details This is only available for some convolutions. If not present + /// in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_scheduler() + { + if constexpr(requires { T::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { T::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + // Return a default or indicate not available + return builder::PipelineScheduler::DEFAULT; + } + } + + /// @brief The pipeline scheduler used by the kernel. + static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); +}; + +/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. +/// @details This specialization provides backward compatibility for reflecting +/// on kernels defined via the `ConvBuilder` interface. It works by first +/// creating the `Instance` via the builder's factory, and then delegating +/// all trait extraction to the `ConvTraits` specialization. +template +struct ConvTraits> +{ + using Factory = builder::ConvFactory; + using Instance = typename Factory::Instance; + + // Delegate to Instance-based ConvTraits + using InstanceConvTraits = ConvTraits; + + // Forward all members from Instance-based traits + static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; + static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; + static constexpr auto layout = InstanceConvTraits::layout; + static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + + static constexpr builder::ElementwiseOperation input_element_op = + InstanceConvTraits::input_element_op; + static constexpr builder::ElementwiseOperation weight_element_op = + InstanceConvTraits::weight_element_op; + static constexpr builder::ElementwiseOperation output_element_op = + InstanceConvTraits::output_element_op; + + static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; + static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; + + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; + static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; + static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; + static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; + static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; + static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; + static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; + static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; +}; + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index c9b45691cc..07f1b94b07 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -14,18 +14,9 @@ #pragma once -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include "instance_traits_util.hpp" +#include namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 7eb4c883b0..d6fc6da0d6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 60f991b1fc..9edfa4d4c9 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 95d1c94de4..c863d2306c 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -9,9 +9,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -371,4 +373,30 @@ constexpr std::string type_or_type_tuple_name() } } +/// @brief Makes a case insensitive comparison of two string views. +/// @param a First string view +/// @param b Second string view +/// @return Whether two string views a equal case insensitive +constexpr bool case_insensitive_equal(std::string_view a, std::string_view b) +{ + if(a.size() != b.size()) + return false; + + for(size_t i = 0; i < a.size(); ++i) + { + char c1 = a[i]; + char c2 = b[i]; + + // Convert to lowercase for comparison + if(c1 >= 'A' && c1 <= 'Z') + c1 += 32; + if(c2 >= 'A' && c2 <= 'Z') + c2 += 32; + + if(c1 != c2) + return false; + } + return true; +} + } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2650f0de16..2af10346e5 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -128,29 +128,14 @@ enum class ElementwiseOperation PASS_THROUGH }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion +// Enums for pipeline versions & schedulers +enum class PipelineVersion { V1, V2, V3, V4, - V5 -}; - -enum struct BlockGemmPipelineScheduler -{ - INTRAWAVE, - INTERWAVE, -}; - -// Enums for the gridwise GEMM pipeline versions. -enum class GridwiseGemmPipelineVersion -{ - V1, - V2, - V3, // Only used in stream-K implementation - V4, + V5, WEIGHT_ONLY }; @@ -186,9 +171,47 @@ enum class ConvFwdSpecialization FILTER_3x3 }; -enum class LoopScheduler +// Enums for the backward data convolution specialization. +enum class ConvBwdDataSpecialization { DEFAULT, + FILTER_1X1_STRIDE1_PAD0, +}; + +// Enums for the backward weight convolution specialization. +enum class ConvBwdWeightSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, + FILTER_1X1_PAD0, + ODD_C, +}; + +// Enums for the Gemm padding. +enum class GemmPadding +{ + DEFAULT, + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING, + O_PADDING, + MO_PADDING, + NO_PADDING, + KO_PADDING, + MNO_PADDING, + MKO_PADDING, + NKO_PADDING, + MNKO_PADDING, +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, INTERWAVE }; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8b5c4519a9..0cb3237f8c 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -64,6 +64,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) +add_ck_builder_test(test_conv_traits + conv/test_conv_traits.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index b58de836de..123034eb77 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -27,7 +27,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V2, + PipelineVersion::V2, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index b8dbf2ca97..240746f546 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } @@ -47,7 +47,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index aba2f29ffd..6366016707 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V3, + PipelineVersion::V3, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 4d01323600..7b303a7bde 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index a30158aa8e..b40dd0b0d7 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -25,7 +25,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index c0b2e613a2..e0dad4e1a1 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 0fea260eac..43ffb3f89a 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -26,7 +26,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V1, + PipelineVersion::V1, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp new file mode 100644 index 0000000000..ca453d2ad4 --- /dev/null +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; + +// Test fixture for ConvTraits tests +class ConvTraitsTest : public ::testing::Test +{ +}; + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); + EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); + EXPECT_EQ(Traits::warp_gemm.m_iter, 4); + EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +} // anonymous namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 1a78028862..accc4048dc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,14 +49,14 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - GridwiseGemmPipelineVersion pipeline_version; + PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemm { - BlockGemmPipelineVersion pipeline_version; - BlockGemmPipelineScheduler scheduler; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; }; static_assert(ckb::BlockGemmDescriptor); @@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; size_t num_groups_to_merge; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); @@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index d145bdfd6c..7fd02a56f7 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -16,7 +16,7 @@ using namespace test; // Common test implementation template constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() { @@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() .src_access_order_b = {1, 0, 2}}; constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = PipelineScheduler::INTRAWAVE}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock, @@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); // Verify pipeline version is correct - if(FwdPipelineVersion == BlockGemmPipelineVersion::V1) + if(FwdPipelineVersion == PipelineVersion::V1) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3) + else if(FwdPipelineVersion == PipelineVersion::V3) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4) + else if(FwdPipelineVersion == PipelineVersion::V4) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5) + else if(FwdPipelineVersion == PipelineVersion::V5) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); // Verify specialization is correct @@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, .num_groups_to_merge = 2, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; @@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1, - .pipeline_version = GridwiseGemmPipelineVersion::V1}; + .pipeline_version = PipelineVersion::V1}; constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .fwd_specialization = FwdConvSpecialization, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 01bb806789..219206c5ce 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace ck { namespace tensor_operation { namespace device { From 4533aa6dbab648adc1a496b6064cb79777c41cf5 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:42:22 -0800 Subject: [PATCH 002/118] Fix compilation errors with clang22. (#3164) * resolve compilation issue with clang22 * add __extension__ for __COUNTER__ usage in ck_tile --- include/ck_tile/core/utility/static_counter.hpp | 13 ++++++++----- profiler/src/profiler_operation_registry.hpp | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/core/utility/static_counter.hpp b/include/ck_tile/core/utility/static_counter.hpp index 84af3dd52f..4828e2e010 100644 --- a/include/ck_tile/core/utility/static_counter.hpp +++ b/include/ck_tile/core/utility/static_counter.hpp @@ -102,11 +102,14 @@ struct static_counter_uniq_; } #define MAKE_SC() \ - ck_tile::static_counter> {} -#define MAKE_SC_WITH(start_, step_) \ - ck_tile::static_counter, start_, step_> {} -#define NEXT_SC(c_) c_.next<__COUNTER__>() -#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>() + __extension__ ck_tile::static_counter> {} +#define MAKE_SC_WITH(start_, step_) \ + __extension__ ck_tile:: \ + static_counter, start_, step_> \ + { \ + } +#define NEXT_SC(c_) __extension__ c_.next<__COUNTER__>() +#define NEXT_SCI(c_, static_i_) __extension__ c_.next<__COUNTER__ + static_i_>() // Usage: // constexpr auto c = MAKE_SC() diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 276b7b38dc..7e6d22d4ce 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -74,6 +74,6 @@ class ProfilerOperationRegistry final #define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y) #define PP_CONCAT_IMPL(x, y) x##y -#define REGISTER_PROFILER_OPERATION(name, description, operation) \ - static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ +#define REGISTER_PROFILER_OPERATION(name, description, operation) \ + __extension__ static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \ ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) From 12922120d2567c3512048d7e8ed37e387a07bab6 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Thu, 6 Nov 2025 14:29:03 +0800 Subject: [PATCH 003/118] add gfx11's barrier following SPG's reference (#3159) * add gfx11's barrier following SPG's reference * re-format the code * minor fix --------- Co-authored-by: ThomasNing --- include/ck_tile/core/arch/arch.hpp | 129 +++++++++++++++++++---------- 1 file changed, 83 insertions(+), 46 deletions(-) mode change 100644 => 100755 include/ck_tile/core/arch/arch.hpp diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp old mode 100644 new mode 100755 index 8620e7337c..5bf8548470 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -136,66 +136,103 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) #endif } -// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html +struct WaitcntLayoutGfx12 +{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0] + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds + CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutGfx11 +{ // vm[15:10] (6), lgkm[9:4] (6), exp unused + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; + CK_TILE_DEVICE static constexpr bool HAS_EXP = false; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; } +}; + +struct WaitcntLayoutLegacy +{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2 + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8] + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4] + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; + + CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) + { + c &= VM_MASK; + return ((c & 0xF) << 0) | ((c & 0x30) << 10); + } + CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); } + CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); } +}; + +// Select active layout +#if defined(__gfx12__) +using Waitcnt = WaitcntLayoutGfx12; +#elif defined(__gfx11__) +using Waitcnt = WaitcntLayoutGfx11; +#else +using Waitcnt = WaitcntLayoutLegacy; +#endif + +//---------------------------------------------- +// Public API: only from_* (constexpr templates) +//---------------------------------------------- struct waitcnt_arg { -#if defined(__gfx12__) - // use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8] - CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111; - - template - CK_TILE_DEVICE static constexpr index_t from_vmcnt() - { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & (cnt << 8); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() - { - return 0; // no export in MI series - } - - template - CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() - { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & cnt; - } + // kMax* exposed for callers; match field widths per-arch +#if defined(__gfx12__) || defined(__gfx11__) + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none #else - // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 - // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV - CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; - - CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; - CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; - CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111; + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split) + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits +#endif template CK_TILE_DEVICE static constexpr index_t from_vmcnt() { - static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); - return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10)); - } - - template - CK_TILE_DEVICE static constexpr index_t from_expcnt() - { - static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]"); - return MAX & (cnt << 4); + static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range"); + return Waitcnt::pack_vm(cnt); } template CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() { - static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); - return MAX & (cnt << 8); + static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range"); + return Waitcnt::pack_lgkm(cnt); } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + if constexpr(Waitcnt::HAS_EXP) + { + // EXP_MASK only exists on legacy +#if !defined(__gfx12__) && !defined(__gfx11__) + static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range"); + return Waitcnt::pack_exp(cnt); +#else + (void)cnt; + return 0; #endif + } + else + { + static_assert(cnt == 0, "expcnt unsupported on this arch"); + return 0; + } + } }; template Date: Thu, 6 Nov 2025 11:26:30 +0100 Subject: [PATCH 004/118] [CK TILE] Convolution remove magic values (#3160) * [CK TILE] Refactor Conv configs and Conv Elementwise * fix * [CK TILE] Convolution remove magix values * fix partitioner --- .../20_grouped_convolution/conv_configs.hpp | 16 +- ...uped_convolution_backward_data_invoker.hpp | 209 +++++++++--------- ...ed_convolution_backward_weight_invoker.hpp | 90 ++++---- ...tion_backward_weight_two_stage_invoker.hpp | 108 +++++---- .../grouped_convolution_forward_invoker.hpp | 96 ++++---- ...nvolution_forward_large_tensor_invoker.hpp | 119 +++++----- .../utils/grouped_convolution_utils.hpp | 69 ++++-- 7 files changed, 355 insertions(+), 352 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index b9edf247cc..8a2a60a197 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -14,19 +14,11 @@ struct ConvConfigBase { - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; - - static constexpr bool TransposeC = false; - static constexpr ck_tile::index_t VectorSizeA = 4; static constexpr ck_tile::index_t VectorSizeB = 8; static constexpr ck_tile::index_t VectorSizeC = 8; - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; @@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index 14a533ffc9..d19d3ac8ec 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdData, + typename GroupedConvTraitsType::BsLayoutBwdData, + typename GroupedConvTraitsType::CLayoutBwdData, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, InDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ConvConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ConvConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index 0e777c5f8a..81b9d402ce 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, const ck_tile::stream_config& s) { - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA; - constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB; - constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker AccDataType, WeiDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index a8e41438c8..8cef2bde65 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker { using WorkspaceDataType = float; - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 4; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker InDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, WeiDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker AccDataType, WorkspaceDataType, // C: Workspace normally Out typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker ConvConfig::K_Warp_Tile, GemmPipelineProblem::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 2290f60d1f..7c8269d13c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr ck_tile::index_t NumGroupsToMerge = 1; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + ConvConfig::VectorSizeA, + ConvConfig::VectorSizeB, + ConvConfig::VectorSizeC, + ConvConfig::NumGroupsToMerge>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; using Kernel = ck_tile::GroupedConvolutionForwardKernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4d983baac5..9d2752727c 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker { std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; } - constexpr int kBlockPerCu = 1; // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< @@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits; + using GroupedConvTraitsTypeDefault = + ck_tile::GroupedConvTraits; using GroupedConvTraitsTypeLargeTensor = ck_tile::GroupedConvTraits; + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - ConvConfig::kPadM, - ConvConfig::kPadN, - ConvConfig::kPadK, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN, + GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK, ConvConfig::DoubleSmemBuffer, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout, - ConvConfig::TransposeC, - false, - false, // Persistent, + typename GroupedConvTraitsTypeDefault::AsLayoutFwd, + typename GroupedConvTraitsTypeDefault::BsLayoutFwd, + typename GroupedConvTraitsTypeDefault::CLayoutFwd, + GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC, + GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent, ConvConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< @@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker WeiDataType, AccDataType, GemmShape, - typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd, + typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd< + ConvConfig::NumWaveGroups>, ck_tile::element_wise::PassThrough, ck_tile::element_wise::PassThrough, OutDataType, - true, - VectorSizeA, - VectorSizeB>; + GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsTypeDefault::VectorSizeA, + GroupedConvTraitsTypeDefault::VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template UniversalGemmPipeline; @@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker using TransformType = ck_tile::TransformConvFwdToGemm; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< ConvConfig::Pipeline>::template GemmPipeline; @@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker AccDataType, OutDataType, typename GroupedConvTraitsType::ImplicitGemmDsLayout, - ck_tile::tensor_layout::gemm::RowMajor, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, CDEElementWise, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, @@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, - ConvConfig::TransposeC, + GroupedConvTraitsType::FixedGemmParams::TransposeC, memory_operation, - 1, - true, + ConvConfig::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; // Use split-image kernel if layout supports it, otherwise use regular kernel @@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker } ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 9b5a60ee1f..8ea6cffa7d 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -74,6 +74,21 @@ struct GroupedConvTraits } public: + // Fixed values for Implicit GEMM + struct FixedGemmParams + { + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + static constexpr bool TransposeC = false; + static constexpr bool FixedVectorSize = true; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Persistent = false; + using ELayout = ck_tile::tensor_layout::gemm::RowMajor; + }; + // Compile time parameters static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = NDimSpatial_; @@ -82,31 +97,43 @@ struct GroupedConvTraits using WeiLayout = WeiLayout_; using DsLayout = DsLayout_; using OutLayout = OutLayout_; + + // Forward Gemm Layouts + using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Data Gemm Layouts + using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor; + // Backward Weight Gemm Layouts + using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor; + using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + + template using GroupedConvImplicitGemmTraitsFwd = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdData = - TileGemmTraits; - using GroupedConvImplicitGemmTraitsBwdWeight = - TileGemmTraits; + TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits; + template + using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits; static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_; static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_; static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_; - static constexpr index_t NumDTensor = DsLayout::size(); + static constexpr ck_tile::index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; From 18e083003fa25a661015542c39b1979200f361cf Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:46:26 +0100 Subject: [PATCH 005/118] [CK_BUILDER] Convolution description (#3163) * Add DirectLoad tparam & clean up headers. * Add convolution traits. * Update inline documentation. * Add more convolution specialization and gemm padding types. * Add additional helper functions & more tests to conv traits. * Fix tests cmake file. * Add case insensitive string comparison * Fix function name overlapping with variable name. * Unify pipeline version and scheduler enums. * Fix includes. * Update test conv traits with unified enums. * Update concepts etc with update unified enum * Fix ckb conv fwd test - unified enum usage. * Dump changes. * Add ostream overloads for all enum classes. * Update detailed() function in ConvDescription * Fix handling union based conv direction. * Add test & update conv description. * Refine tree view. * Update copyrights * Fix merge artifacts * Update detailed tree conv description * Fix clang-format --- .../include/ck_tile/builder/builder_utils.hpp | 62 ---- .../builder/conv_signature_predicates.hpp | 16 + .../builder/reflect/conv_description.hpp | 268 +++++++++++++++++ .../ck_tile/builder/reflect/conv_traits.hpp | 2 +- .../builder/reflect/instance_traits_util.hpp | 5 +- .../builder/reflect/tree_formatter.hpp | 106 +++++++ .../builder/include/ck_tile/builder/types.hpp | 275 ++++++++++++++++++ experimental/builder/test/CMakeLists.txt | 3 + .../builder/test/test_conv_description.cpp | 169 +++++++++++ 9 files changed, 842 insertions(+), 64 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp create mode 100644 experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp create mode 100644 experimental/builder/test/test_conv_description.cpp diff --git a/experimental/builder/include/ck_tile/builder/builder_utils.hpp b/experimental/builder/include/ck_tile/builder/builder_utils.hpp index 5b4981c630..f16d96bec6 100644 --- a/experimental/builder/include/ck_tile/builder/builder_utils.hpp +++ b/experimental/builder/include/ck_tile/builder/builder_utils.hpp @@ -78,66 +78,4 @@ struct UnsupportedEnumValue { }; -// Helper functions to convert enums to strings -constexpr std::string_view ConvDirectionToString(ConvDirection dir) -{ - switch(dir) - { - case ConvDirection::FORWARD: return "Forward"; - case ConvDirection::BACKWARD_DATA: return "Backward Data"; - case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; - default: return "Unknown"; - } -} - -constexpr std::string_view DataTypeToString(DataType dt) -{ - switch(dt) - { - case DataType::FP16: return "FP16"; - case DataType::FP32: return "FP32"; - case DataType::BF16: return "BF16"; - case DataType::FP8: return "FP8"; - case DataType::I8: return "I8"; - case DataType::U8: return "U8"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout1D layout) -{ - switch(layout) - { - case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK"; - case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK"; - case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW"; - case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout2D layout) -{ - switch(layout) - { - case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK"; - case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK"; - case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW"; - case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW"; - default: return "Unknown"; - } -} - -constexpr std::string_view LayoutToString(GroupConvLayout3D layout) -{ - switch(layout) - { - case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK"; - case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK"; - case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW"; - case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW"; - default: return "Unknown"; - } -} - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp index f016a342d3..3869c7b538 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); // Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); // Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); // Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. template concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + ConvDirectionIsForward && (Sig.device_operation._fwd == FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); @@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward = // Predicate for DeviceGroupedConvBwdWeight operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); // Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); // Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); // Predicate for DeviceGroupedConvBwdWeightMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); // Predicate for DeviceGroupedConvBwdWeight_Dl operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + ConvDirectionIsBackwardWeight && (Sig.device_operation._bwd_weight == BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); @@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight = // Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); // Predicate for DeviceGroupedConvBwdDataMultipleD operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); // Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. template concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + ConvDirectionIsBackwardData && (Sig.device_operation._bwd_data == BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp new file mode 100644 index 0000000000..0b58f5a3b7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +/// @file conv_description.hpp +/// @brief Provides human-readable descriptions of ConvBuilder configurations + +namespace ck_tile::reflect::conv { + +struct ConvSignatureInfo +{ + int spatial_dim; + builder::ConvDirection direction; + std::variant + layout; + builder::DataType data_type; + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; +}; + +// Algorithm information - groups all algorithm-related configuration +struct GemmAlgorithmInfo +{ + int thread_block_size; + DataTileInfo tile_dims; + WarpGemmParams warp_gemm; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; + OutputTileTransferInfo c_tile_transfer; + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; + std::variant + conv_specialization; + builder::GemmPadding padding; +}; + +// Provides human-readable descriptions of ConvBuilder configurations. +struct ConvDescription +{ + ConvSignatureInfo signature; + GemmAlgorithmInfo algorithm; + + // Brief one-line summary + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " << signature.direction << " convolution"; + return oss.str(); + } + + // Detailed hierarchical description + std::string detailed() const + { + TreeFormatter f; + f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel"); + f.writeLine(1, "Signature"); + f.writeLine(2, "Tensor Type: ", signature.data_type); + f.writeLine(2, "Memory Layout: ", signature.layout); + f.writeLine(2, "Input elementwise operation: ", signature.input_element_op); + f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); + f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); + + f.writeLine(1, "Algorithm"); + // Compute Block section + f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); + f.writeLine(2, + "Data tile size: ", + algorithm.tile_dims.m, + "×", + algorithm.tile_dims.n, + "×", + algorithm.tile_dims.k); + f.writeLine(2, "Gemm padding: ", algorithm.padding); + f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization); + // Pipeline section + f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version); + f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler); + f.writeLine(2, "Warp Gemm parameters: "); + f.writeLine( + 3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n); + f.writeLast(3, + "Number of warp gemm iterations: ", + algorithm.warp_gemm.m_iter, + "×", + algorithm.warp_gemm.n_iter); + + // Memory Access section + f.writeLine(2, "Memory access:"); + + f.writeLine(3, "A Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.a_tile_transfer.tile_dimensions.k0, + "×", + algorithm.a_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.a_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.a_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.a_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.a_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLine(3, "B Tile transfer: "); + f.writeLine(4, + "Tile dimensions: ", + algorithm.b_tile_transfer.tile_dimensions.k0, + "×", + algorithm.b_tile_transfer.tile_dimensions.m_or_n, + "×", + algorithm.b_tile_transfer.tile_dimensions.k1, + "×"); + f.writeLine( + 4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1); + f.writeLine(4, + "Spatial thread distribution over the data tile: ", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]); + f.writeLine(4, + "The order of accessing data tile axes: ", + algorithm.b_tile_transfer.transfer_params.src_access_order[0], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[1], + "×", + algorithm.b_tile_transfer.transfer_params.src_access_order[2]); + f.writeLine(4, + "Vectorized memory access axis index (with contiguous memory): ", + algorithm.b_tile_transfer.transfer_params.src_vector_dim); + f.writeLine(4, + "Vector access (GMEM read) instruction size: ", + algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector); + f.writeLine(4, + "Vector access (LDS write) instruction size: ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + f.writeLast(4, + "LDS data layout padding (to prevent bank conflicts): ", + algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + + f.writeLast(3, "C Tile transfer: "); + f.writeLine(4, + "Data shuffle (number of gemm instructions per iteration): ", + algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + "×", + algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + f.writeLine(4, + "Spatial thread distribution used to store data: ", + algorithm.c_tile_transfer.thread_cluster_dims[0], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[1], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[2], + "×", + algorithm.c_tile_transfer.thread_cluster_dims[3]); + f.writeLast(4, + "Vector access (GMEM write) instruction size: ", + algorithm.c_tile_transfer.scalar_per_vector); + f.writeLast(2); + f.writeLast(1); + return f.getString(); + } + + // Educational explanation of optimization choices + std::string explain() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } + + // Performance characteristics and use case guidance + std::string suggest() const + { + std::ostringstream oss; + // Placeholder for future implementation + return oss.str(); + } +}; + +// Helper concept to detect if a type has InstanceTraits specialization +template +concept HasInstanceTraits = requires { typename InstanceTraits; }; + +// Helper concept to detect ConvBuilder types +template +concept IsConvBuilder = requires { + typename T::Factory; + typename T::Instance; +}; + +// Primary factory function: Create ConvDescription from Instance type directly +template + requires HasInstanceTraits +ConvDescription Describe() +{ + using Traits = ConvTraits; + + return ConvDescription{ + .signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim, + .direction = Traits::direction, + .layout = Traits::layout, + .data_type = Traits::data_type, + .input_element_op = Traits::input_element_op, + .weight_element_op = Traits::weight_element_op, + .output_element_op = Traits::output_element_op}, + .algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size, + .tile_dims = Traits::tile_dims, + .warp_gemm = Traits::warp_gemm, + .a_tile_transfer = Traits::a_tile_transfer, + .b_tile_transfer = Traits::b_tile_transfer, + .c_tile_transfer = Traits::c_tile_transfer, + .pipeline_version = Traits::pipeline_version, + .pipeline_scheduler = Traits::pipeline_scheduler, + .conv_specialization = Traits::conv_specialization, + .padding = Traits::gemm_padding}}; +} + +// Backward compatibility: Create ConvDescription from Builder type +template + requires IsConvBuilder && (!HasInstanceTraits) +ConvDescription Describe() +{ + // Delegate to Instance-based version + using Instance = typename Builder::Instance; + return Describe(); +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a74d77d155..86cf11f647 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index c863d2306c..e4d154ae10 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -13,7 +13,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include diff --git a/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp new file mode 100644 index 0000000000..6a80a994ee --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/tree_formatter.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +namespace ck_tile::reflect { + +// Helper class for formatting hierarchical tree structures with proper indentation +// and tree-drawing characters (├─, └─, │, etc.) +// +// Example Usage: +// +// TreeFormatter f; +// f.writeLine(0, "Root"); +// f.writeLine(1, "Branch 1"); +// f.writeLine(2, "Item 1a"); +// f.writeLast(2, "Item 1b"); +// f.writeLast(1, "Branch 2"); +// f.writeLast(2, "Item 2a"); +// std::cout << f.getString() << "\n"; +// +// Generated Output: +// +// Root +// ├─ Branch 1 +// │ ├─ Item 1a +// │ └─ Item 1b +// └─ Branch 2 +// └─ Item 2a +class TreeFormatter +{ + public: + TreeFormatter() = default; + + // Write a line at the specified indentation level (branch continues after this) + template + void writeLine(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, false, std::forward(args)...); + } + + // Write the last line at the specified indentation level (branch ends) + template + void writeLast(int indent_level, Args&&... args) + { + writeLineImpl(indent_level, true, std::forward(args)...); + } + + // Get the formatted string (removes trailing newline if present) + std::string getString() const + { + std::string result = oss_.str(); + if(!result.empty() && result.back() == '\n') + { + result.pop_back(); + } + return result; + } + + private: + std::ostringstream oss_; + std::vector is_last_at_level_; // Tracks which levels have ended + + // Implementation of line writing with tree symbols + template + void writeLineImpl(int indent_level, bool is_last, Args&&... args) + { + // Ensure we have enough tracking space + if(static_cast(indent_level) >= is_last_at_level_.size()) + { + is_last_at_level_.resize(indent_level + 1, false); + // Level 0 (root) should always be treated as "last" since it has no tree symbols + if(is_last_at_level_.size() > 0) + { + is_last_at_level_[0] = true; + } + } + + // Draw the tree structure + // Start from level 1 (skip level 0 which is the root with no symbols) + for(int i = 1; i < indent_level; ++i) + { + // For all parent levels, draw vertical line or space based on whether they ended + oss_ << (is_last_at_level_[i] ? " " : "│ "); + } + + // Draw the branch symbol for the current level + if(indent_level > 0) + { + oss_ << (is_last ? "└─ " : "├─ "); + } + + // Write the content using fold expression with direct stream insertion + ((oss_ << std::forward(args)), ...); + + oss_ << '\n'; + + // Update tracking for this level AFTER writing the line + // This ensures future lines at deeper levels know if this level ended + is_last_at_level_[indent_level] = is_last; + } +}; + +} // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 2af10346e5..a58c994288 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + namespace ck_tile::builder { enum class DataType @@ -215,4 +219,275 @@ enum class PipelineScheduler INTERWAVE }; +// ostream operator overloads for enum classes +inline std::ostream& operator<<(std::ostream& os, DataType dt) +{ + using enum DataType; + switch(dt) + { + case FP16: return os << "FP16"; + case FP32: return os << "FP32"; + case BF16: return os << "BF16"; + case FP8: return os << "FP8"; + case I8: return os << "I8"; + case U8: return os << "U8"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + using enum ConvDirection; + switch(dir) + { + case FORWARD: return os << "Forward"; + case BACKWARD_DATA: return os << "Backward Data"; + case BACKWARD_WEIGHT: return os << "Backward Weight"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout) +{ + using enum GroupConvLayout1D; + switch(layout) + { + case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK"; + case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK"; + case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW"; + case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout) +{ + using enum GroupConvLayout2D; + switch(layout) + { + case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK"; + case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK"; + case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW"; + case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout) +{ + using enum GroupConvLayout3D; + switch(layout) + { + case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK"; + case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK"; + case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW"; + case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op) +{ + using enum FwdGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK: + return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"; + case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"; + case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: + return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"; + case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor: + return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op) +{ + using enum BwdDataGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD"; + case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"; + case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: + return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op) +{ + using enum BwdWeightGroupConvDeviceOperation; + switch(op) + { + case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight"; + case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3: + return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + case DeviceGroupedConvBwdWeight_Wmma_CShuffle: + return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"; + case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD"; + case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle: + return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) +{ + using enum ElementwiseOperation; + switch(op) + { + case BIAS: return os << "BIAS"; + case BIAS_CLAMP: return os << "BIAS_CLAMP"; + case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP"; + case BILINEAR: return os << "BILINEAR"; + case CLAMP: return os << "CLAMP"; + case SCALE: return os << "SCALE"; + case PASS_THROUGH: return os << "PASS_THROUGH"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) +{ + using enum PipelineVersion; + switch(ver) + { + case V1: return os << "V1"; + case V2: return os << "V2"; + case V3: return os << "V3"; + case V4: return os << "V4"; + case V5: return os << "V5"; + case WEIGHT_ONLY: return os << "WEIGHT_ONLY"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) +{ + using enum GemmSpecialization; + switch(spec) + { + case Default: return os << "Default"; + case MPadding: return os << "MPadding"; + case NPadding: return os << "NPadding"; + case KPadding: return os << "KPadding"; + case MNPadding: return os << "MNPadding"; + case MKPadding: return os << "MKPadding"; + case NKPadding: return os << "NKPadding"; + case MNKPadding: return os << "MNKPadding"; + case OPadding: return os << "OPadding"; + case MOPadding: return os << "MOPadding"; + case NOPadding: return os << "NOPadding"; + case KOPadding: return os << "KOPadding"; + case MNOPadding: return os << "MNOPadding"; + case MKOPadding: return os << "MKOPadding"; + case NKOPadding: return os << "NKOPadding"; + case MNKOPadding: return os << "MNKOPadding"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +{ + using enum ConvFwdSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_3x3: return os << "FILTER_3x3"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) +{ + using enum ConvBwdDataSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) +{ + using enum ConvBwdWeightSpecialization; + switch(spec) + { + case DEFAULT: return os << "DEFAULT"; + case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0"; + case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0"; + case ODD_C: return os << "ODD_C"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) +{ + using enum GemmPadding; + switch(padding) + { + case DEFAULT: return os << "DEFAULT"; + case M_PADDING: return os << "M_PADDING"; + case N_PADDING: return os << "N_PADDING"; + case K_PADDING: return os << "K_PADDING"; + case MN_PADDING: return os << "MN_PADDING"; + case MK_PADDING: return os << "MK_PADDING"; + case NK_PADDING: return os << "NK_PADDING"; + case MNK_PADDING: return os << "MNK_PADDING"; + case O_PADDING: return os << "O_PADDING"; + case MO_PADDING: return os << "MO_PADDING"; + case NO_PADDING: return os << "NO_PADDING"; + case KO_PADDING: return os << "KO_PADDING"; + case MNO_PADDING: return os << "MNO_PADDING"; + case MKO_PADDING: return os << "MKO_PADDING"; + case NKO_PADDING: return os << "NKO_PADDING"; + case MNKO_PADDING: return os << "MNKO_PADDING"; + default: return os << "Unknown"; + } +} + +inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) +{ + using enum PipelineScheduler; + switch(sched) + { + case DEFAULT: return os << "DEFAULT"; + case INTRAWAVE: return os << "INTRAWAVE"; + case INTERWAVE: return os << "INTERWAVE"; + default: return os << "Unknown"; + } +} + +// ostream operator overload for std::variant of layout types +inline std::ostream& +operator<<(std::ostream& os, + const std::variant& layout) +{ + std::visit([&os](const auto& l) { os << l; }, layout); + return os; +} + +// ostream operator overload for std::variant of convolution specializations +inline std::ostream& operator<<(std::ostream& os, + const std::variant& spec) +{ + std::visit([&os](const auto& s) { os << s; }, spec); + return os; +} + } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 0cb3237f8c..b776edbcde 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -67,6 +67,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test add_ck_builder_test(test_conv_traits conv/test_conv_traits.cpp) +add_ck_builder_test(test_conv_description + test_conv_description.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp new file mode 100644 index 0000000000..97af4af795 --- /dev/null +++ b/experimental/builder/test/test_conv_description.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include "testing_utils.hpp" +#include "impl/conv_signature_types.hpp" +#include "impl/conv_algorithm_types.hpp" + +namespace { + +namespace ckb = ck_tile::builder; +namespace ckr = ck_tile::reflect::conv; +namespace ckt = ck_tile::test; + +// Defines the signature of the convolution operation to be tested. +// This includes dimensionality, direction, data layout, and data type. +struct ConvSignature +{ + int spatial_dim = 2; + ckb::ConvDirection direction = ckb::ConvDirection::FORWARD; + ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + ckb::DataType data_type = ckb::DataType::FP16; + ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH; + ckb::GroupConvDeviceOp device_operation = + ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; +}; +static_assert(ckb::ConvSignatureDescriptor); + +struct DefaultAlgorithm +{ + ckb::test::ThreadBlock thread_block{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; + + ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 16, + .n_per_xdl = 16, + .m_xdl_per_wave = 4, + .n_xdl_per_wave = 4}; + + ckb::test::BlockTransferABC block_transfer{ + .block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8}, + .block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = true, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {.order = {0, 1, 2}}, + .block_transfer_access_order_b = {.order = {0, 1, 2}}, + .src_access_order_a = {.order = {0, 1, 2}}, + .src_access_order_b = {.order = {0, 1, 2}}}; + + ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = ckb::PipelineScheduler::INTRAWAVE}; +}; +static_assert(ckb::ConvAlgorithmDescriptor); + +TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); +} + +TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) +{ + static constexpr const ConvSignature SIGNATURE; + static constexpr const DefaultAlgorithm ALGORITHM; + using Builder = ckb::ConvBuilder; + EXPECT_THAT(ckr::Describe().detailed(), + ckt::StringEqWithDiff( // + "2D Forward Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "├─ Algorithm\n" + "│ ├─ Thread block size: 256\n" + "│ ├─ Data tile size: 256×256×32\n" + "│ ├─ Gemm padding: DEFAULT\n" + "│ ├─ Convolution specialization: DEFAULT\n" + "│ ├─ Pipeline version: V4\n" + "│ ├─ Pipeline scheduler: INTRAWAVE\n" + "│ ├─ Warp Gemm parameters: \n" + "│ │ ├─ subtile size: 16×16\n" + "│ │ └─ Number of warp gemm iterations: 4×4\n" + "│ ├─ Memory access:\n" + "│ │ ├─ A Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ ├─ B Tile transfer: \n" + "│ │ │ ├─ Tile dimensions: 4×256×8×\n" + "│ │ │ ├─ The innermost K subdimension size: 8\n" + "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" + "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" + "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" + "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + "│ │ └─ C Tile transfer: \n" + "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + "│ │ └─ Vector access (GMEM write) instruction size: 8\n" + "│ └─ \n" + "└─ ")); +} + +// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory +// does not have a specialization for backward data convolutions. The test fails with: +// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'" +// +// To enable this test, a ConvFactory specialization for backward data operations must be +// implemented first. +// +// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription) +// { +// struct BackwardDataSignature +// { +// int spatial_dim = 2; +// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA; +// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; +// ckb::DataType data_type = ckb::DataType::FP16; +// ckb::ElementwiseOperation elementwise_operation = +// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation = +// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; +// }; +// static_assert(ckb::ConvSignatureDescriptor); +// +// static constexpr const BackwardDataSignature SIGNATURE; +// static constexpr const DefaultAlgorithm ALGORITHM; +// using Builder = ckb::ConvBuilder; +// +// // Verify Brief works +// EXPECT_THAT(ckr::Describe().brief(), +// ckt::StringEqWithDiff("2D Backward Data convolution")); +// +// // Verify detailed works - to be updated once ConvFactory is implemented +// EXPECT_THAT(ckr::Describe().detailed(), +// ckt::StringEqWithDiff("PLACEHOLDER")); +// } +} // namespace From 76c4c12f5959adcd56d1627a1d1ce885deb9d096 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Fri, 7 Nov 2025 00:07:39 +0100 Subject: [PATCH 006/118] Add .clangd and CMakeUserPresets.json to .gitignore (#3171) --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 6641e5bc58..2641a661d8 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,12 @@ docs/doxygen/xml cmake-build*/ build*/ +# LSP configuration +.clangd + +# User-defined CMake presets +CMakeUserPresets.json + # Python virtualenv .venv/ From 5f3cae3e28a042e411afcd2e54b16cc6909c5bbb Mon Sep 17 00:00:00 2001 From: JH-Leon-KIM-AMD Date: Fri, 7 Nov 2025 02:29:48 +0200 Subject: [PATCH 007/118] [CK_BUILDER]ckb add remining fwd conv device ops (#3155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. * add device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk * Add full DL configurability with Option A implementation - Added 5 DL descriptor structs (39 configurable parameters) - Added 10 C++20 concepts for type-safe validation - Updated factory to read all parameters from descriptors - Updated test helper to populate all descriptors - All tests passing (13/13 including 3 new DL tests) * Add factory and test support for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add factory specialization for Large_Tensor device operation (conv_factory.hpp lines 1145-1265) - Add macro collision workaround using pragma push/pop (conv_factory.hpp lines 43-51) - Add test helper function run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor - Add builder test file test_ckb_conv_fwd_2d_large_tensor_fp16.cpp with 2 test cases - Update CMakeLists.txt to include new test file - Reuse existing ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle descriptor - Map all 42 template parameters identical to regular XDL CShuffle - All 15 builder tests passing including 2 new Large_Tensor tests Completes Task 350: All 4 forward convolution device operations now supported in CK Builder. * Update copyright headers to new format - Change copyright format to: Copyright (C) Advanced Micro Devices, Inc., or its affiliates. - Reorder headers: Copyright first, then SPDX-License-Identifier - Updated files: * experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp * experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp * experimental/builder/include/ck_tile/builder/device_op_types.hpp * fix c++ 18 format * Fix clang-format-18 error in device_op_types.hpp --------- Co-authored-by: Ville Pietilä Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com> --- .../builder/conv_algorithm_concepts.hpp | 83 ++++++ .../include/ck_tile/builder/conv_factory.hpp | 271 ++++++++++++++++++ .../ck_tile/builder/device_op_types.hpp | 22 ++ experimental/builder/test/CMakeLists.txt | 2 + .../conv/test_ckb_conv_fwd_2d_dl_fp16.cpp | 69 +++++ ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 53 ++++ .../test/impl/conv_algorithm_types.hpp | 80 ++++++ .../test/utils/ckb_conv_test_common.hpp | 145 ++++++++++ 8 files changed, 725 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/device_op_types.hpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp create mode 100644 experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index e43f910a73..6006efe4f8 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +/******************************************** */ +/* DL-specific descriptors and requirements */ +/******************************************** */ + +// Concept for DL thread configuration +template +concept DlThreadConfigDescriptor = requires(T t) { + { t.k0_per_block } -> std::convertible_to; + { t.k1 } -> std::convertible_to; + { t.m1_per_thread } -> std::convertible_to; + { t.n1_per_thread } -> std::convertible_to; + { t.k_per_thread } -> std::convertible_to; +}; + +// Concept for DL thread cluster +template +concept DlThreadClusterDescriptor = requires(T t) { + { t.m1_xs } -> std::convertible_to>; + { t.n1_xs } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_M0_M1_K1 format +template +concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL block transfer K0_N0_N1_K1 format +template +concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) { + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; +}; + +// Concept for DL C thread transfer +template +concept DlCThreadTransferDescriptor = requires(T t) { + { t.src_dst_access_order } -> std::convertible_to>; + { t.src_dst_vector_dim } -> std::convertible_to; + { t.dst_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if algorithm specifies DL thread config +template +concept SpecifiesDlThreadConfig = requires { + { T::dl_thread_config } -> DlThreadConfigDescriptor; +}; + +// Concept to check if algorithm specifies DL thread cluster +template +concept SpecifiesDlThreadCluster = requires { + { T::dl_thread_cluster } -> DlThreadClusterDescriptor; +}; + +// Concept to check if algorithm specifies DL A block transfer +template +concept SpecifiesDlBlockTransferA = requires { + { T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL B block transfer +template +concept SpecifiesDlBlockTransferB = requires { + { T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor; +}; + +// Concept to check if algorithm specifies DL C thread transfer +template +concept SpecifiesDlCThreadTransfer = requires { + { T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 1ccc190ba2..e40199987d 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -36,9 +36,21 @@ #pragma once +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +// WORKAROUND: Macro namespace collision in upstream CK device operation headers. +// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and +// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define +// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors. +// Use pragma push/pop to isolate the Large_Tensor header's macro scope. +#pragma push_macro("GridwiseGemmTemplateParameters") +#ifdef GridwiseGemmTemplateParameters +#undef GridwiseGemmTemplateParameters +#endif +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#pragma pop_macro("GridwiseGemmTemplateParameters") #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" @@ -990,4 +1002,263 @@ struct ConvFactory GRIDWISE_GEMM_PIPELINE_VERSION>; }; +// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel using Direct Load (DL) approach. +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesDlThreadConfig, + "DL algorithm must specify thread config."); + static_assert(SpecifiesDlThreadCluster, + "DL algorithm must specify thread cluster."); + static_assert(SpecifiesDlBlockTransferA, + "DL algorithm must specify A block transfer."); + static_assert(SpecifiesDlBlockTransferB, + "DL algorithm must specify B block transfer."); + static_assert(SpecifiesDlCThreadTransfer, + "DL algorithm must specify C thread transfer."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + SPATIAL_DIM, + typename Types::ADataType, + typename Types::BDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Types::AccDataType, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + FWD_CONV_SPECIALIZATION, + GEMM_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance +// of a grouped forward convolution kernel with large tensor support (N-splitting). +template + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +struct ConvFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(factory_internal::GetTensorLayout()); + using Types = factory_internal::ConvTensorTypes; + using Ops = factory_internal::ElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(SpecifiesThreadBlock, + "The convolution algorithm descriptor must specify thread block info."); + static_assert(SpecifiesGridwiseXdlGemm, + "The convolution algorithm descriptor must specify gridwise GEMM info."); + static_assert(SpecifiesBlockTransfer, + "The convolution algorithm descriptor must specify block transfer info."); + static_assert(SpecifiesLdsTransfer, + "The convolution algorithm descriptor must specify LDS transfer info."); + static_assert( + SpecifiesThreadClusterAccessOrder, + "The convolution algorithm descriptor must specify thread cluster access order info."); + static_assert(SpecifiesSourceAccessOrder, + "The convolution algorithm descriptor must specify source access order info."); + static_assert(SpecifiesFwdConcSpecialization, + "The convolution algorithm descriptor must specify forward convolution " + "specialization."); + static_assert(SpecifiesGemmSpecialization, + "The convolution algorithm descriptor must specify gemm specialization."); + static_assert(SpecifiesNumPrefetchStages, + "The convolution algorithm descriptor must specify number of prefetch stages."); + static_assert(SpecifiesLoopScheduler, + "The convolution algorithm descriptor must specify loop scheduler."); + + static constexpr auto FWD_CONV_SPECIALIZATION = + factory_internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = + factory_internal::SetGemmSpecialization(); + static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); + static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + factory_internal::SetFwdConvABlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + factory_internal::SetFwdConvBBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + factory_internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance with large tensor support. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER>; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/device_op_types.hpp b/experimental/builder/include/ck_tile/builder/device_op_types.hpp new file mode 100644 index 0000000000..0e779fdf4e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/device_op_types.hpp @@ -0,0 +1,22 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::builder { + +// Enumeration for CK Device Operation types. +// This allows the builder to select which device operation template to instantiate +// based on the user's requirements. +enum class DeviceOpType +{ + // Forward Convolution - Non-grouped + CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet) + + // Forward Convolution - Grouped + GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to: + // DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +}; + +} // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b776edbcde..43c4fd4857 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_2d_bf16.cpp conv/test_ckb_conv_fwd_2d_fp16.cpp conv/test_ckb_conv_fwd_2d_fp32.cpp + conv/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp conv/test_ckb_conv_fwd_3d_bf16.cpp conv/test_ckb_conv_fwd_3d_fp16.cpp conv/test_ckb_conv_fwd_3d_fp32.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp new file mode 100644 index 0000000000..12730bab19 --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -0,0 +1,69 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK(); +} + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; + + run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp new file mode 100644 index 0000000000..0216c5907d --- /dev/null +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -0,0 +1,53 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_common.hpp" + +using namespace ck_tile::builder::test_utils; + +namespace ck_tile::builder::testing { + +TEST(FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::DEFAULT>(); +} + +TEST( + FwdConvInstances, + Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0) +{ + constexpr ConvSignature FwdConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, + .data_type = DataType::FP16, + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor}; + + constexpr ThreadBlock FwdThreadBlock{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; + + run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + FwdConvSignature, + FwdThreadBlock, + ConvFwdSpecialization::FILTER_1X1_PAD0>(); +} + +} // namespace ck_tile::builder::testing diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index accc4048dc..88c5b5787a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -214,4 +214,84 @@ static_assert( static_assert( ckb::SpecifiesLoopScheduler); +// DL-specific descriptors +struct DlThreadConfig +{ + size_t k0_per_block; + size_t k1; + size_t m1_per_thread; + size_t n1_per_thread; + size_t k_per_thread; +}; +static_assert(ckb::DlThreadConfigDescriptor); + +struct DlThreadCluster +{ + std::array m1_xs; // e.g., {8, 2} + std::array n1_xs; // e.g., {8, 2} +}; +static_assert(ckb::DlThreadClusterDescriptor); + +struct DlBlockTransferK0M0M1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor); + +struct DlBlockTransferK0N0N1K1 +{ + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; +}; +static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor); + +struct DlCThreadTransfer +{ + std::array src_dst_access_order; + size_t src_dst_vector_dim; + size_t dst_scalar_per_vector; +}; +static_assert(ckb::DlCThreadTransferDescriptor); + +struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK +{ + ThreadBlock thread_block; + ConvFwdSpecialization fwd_specialization; + GemmSpecialization gemm_specialization; + DlThreadConfig dl_thread_config; + DlThreadCluster dl_thread_cluster; + DlBlockTransferK0M0M1K1 dl_block_transfer_a; + DlBlockTransferK0N0N1K1 dl_block_transfer_b; + DlCThreadTransfer dl_c_thread_transfer; +}; +static_assert( + ckb::ConvAlgorithmDescriptor); +static_assert( + ckb::SpecifiesThreadBlock); +static_assert(ckb::SpecifiesFwdConcSpecialization< + ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>); +static_assert( + ckb::SpecifiesGemmSpecialization); +static_assert( + ckb::SpecifiesDlThreadConfig); +static_assert( + ckb::SpecifiesDlThreadCluster); +static_assert( + ckb::SpecifiesDlBlockTransferA); +static_assert( + ckb::SpecifiesDlBlockTransferB); +static_assert( + ckb::SpecifiesDlCThreadTransfer); + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7fd02a56f7..14fae566f6 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() EXPECT_NE(invoker_ptr, nullptr); } +template +constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK() +{ + // DL thread configuration + constexpr DlThreadConfig DlThreadCfg{ + .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + + // DL thread cluster + constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; + + // DL A block transfer - K0_M0_M1_K1 format + constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL B block transfer - K0_N0_N1_K1 format + constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; + + // DL C thread transfer + constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}; + + constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .dl_thread_config = DlThreadCfg, + .dl_thread_cluster = DlCluster, + .dl_block_transfer_a = DlBlockTransferA, + .dl_block_transfer_b = DlBlockTransferB, + .dl_c_thread_transfer = DlCTransfer}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + +// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// Note: Large_Tensor has identical parameters to regular XDL CShuffle +template +constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor() +{ + constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8, + .bk1 = 8, + .m_per_xdl = 32, + .n_per_xdl = 32, + .m_xdl_per_wave = 2, + .n_xdl_per_wave = 1}; + + constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; + + // Large_Tensor uses the same descriptor as regular XDL CShuffle + constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{ + .thread_block = FwdThreadBlock, + .gridwise_gemm = FwdGemmParams, + .block_transfer = FwdBlockTransfer, + .fwd_specialization = FwdConvSpecialization, + .gemm_specialization = GemmSpecialization::MNKPadding, + .num_gemm_k_prefetch_stages = 1, + .num_groups_to_merge = 1, + .loop_scheduler = LoopScheduler::DEFAULT}; + + using Builder = ConvBuilder; + + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + + EXPECT_TRUE( + kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor")); + + // Verify specialization is correct + if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT) + EXPECT_TRUE(kernel_string.find("Default") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) + EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos); + else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3) + EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos); + + const auto invoker_ptr = instance.MakeInvokerPointer(); + EXPECT_NE(invoker_ptr, nullptr); +} + } // namespace ck_tile::builder::test_utils From d04eba4ae37c8c2d40855f02aa861e1ac1ec7b3f Mon Sep 17 00:00:00 2001 From: Xudong Yuan Date: Fri, 7 Nov 2025 08:45:41 +0800 Subject: [PATCH 008/118] Ck moe mxfp4 blockm32 (#3098) * block_m = 32 * ck block_m = 32 * aiter/3rdparty/composable_kernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp format * mxfp4_moe v1 pipe * update format --------- Co-authored-by: zhimding Co-authored-by: lalala-sh Co-authored-by: felix --- .../moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp | 12 +- ...xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp | 3 +- ...ne_xdlops_b_preshuffle_mx_moe_selector.hpp | 24 +- ...pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp | 891 ++++++++++++++++++ ...pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp | 234 +++-- .../impl/device_moe_mx_gemm_bpreshuffle.hpp | 2 +- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 468 +++++---- 7 files changed, 1357 insertions(+), 277 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 1adf039b70..ebb73ca7e0 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 64, KPerBlock, + MPerBlock, 128, KPerBlock, 16, 16, 16, 16, - 4, 2, + 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, @@ -213,10 +213,10 @@ int main(int argc, char* argv[]) ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t N = 6144; - ck::index_t K = 4096; + ck::index_t N = 7168; + ck::index_t K = 256; ck::index_t experts = 8; - ck::index_t tokens = 832; + ck::index_t tokens = 208; ck::index_t topk = 2; if(argc == 1) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index b3b3d312c7..b621c3a93d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< }); }); - HotLoopScheduler(); + if constexpr(MPerBlock >= 64) + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp index 6789d26a45..5223993671 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp" namespace ck { @@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector() } else { - return nullptr; + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp new file mode 100644 index 0000000000..fc5cb60c37 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp @@ -0,0 +1,891 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves + + num_buffer_load_a_scale + num_buffer_load_b_scale; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 128) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0; + }); + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 2b936c8d25..7473d2f2e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 2) + { - // Group num_mfma_perstage num_ds_read_a_perstage - // since we want to reuse a local register buffer - constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; - constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; - constexpr auto num_ds_read_a_mfma_perstage = - math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); - constexpr auto num_ds_read_a_prefetch_stages = 2; + constexpr auto num_ds_read_a_prefetch_stages = 2; - constexpr auto buffer_load_perstage_more = - math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_less = - math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); - constexpr auto buffer_load_perstage_stage2 = - math::integer_divide_floor((num_buffer_load_stage2), 2); + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); - constexpr auto buffer_load_stages_more = - num_buffer_load_stage1 - - math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * - ((num_total_stages - 2)); + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); - constexpr auto buffer_load_issue_point_interval_more = - num_mfma_perstage / buffer_load_perstage_more; - constexpr auto buffer_load_issue_point_interval_less = - num_mfma_perstage / buffer_load_perstage_less; - constexpr auto buffer_load_issue_point_interval_stage2 = - num_mfma_perstage / buffer_load_perstage_stage2; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; - // Stage 1 - // global read more - static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + else + { + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + + num_buffer_load_b_scale; + constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( + num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < + mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } - }); - }); - - // global read less - static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + else { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read } }); - }); - - // Stage 2, Sync - // lds synchronization, prefetch next loop local A - static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { - static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) - { - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - }); - }); + } } template ()) - { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); - } + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } @@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, // NPerBlock>; -#if 0 template @@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + ignore = a_element_op; ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( - make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); - // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; if(expert_block_id * MPerBlock >= max_token_id) return; const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); - const auto block_mn = [&]() -> std::pair { if constexpr(NSwizzle) { @@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); constexpr auto AKThreads = AK0Threads * AK1Threads; constexpr auto AMRepeats = MPerBlock / AMThreads; - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads; if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads]; index_t token_offset = fused_token & 0xffffff; if constexpr(!IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); + const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); - const index_t expert_scale_stride = - __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) * - math::integer_divide_ceil(problem.K, ScaleBlockSize)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + // Gride buffer creation const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); // B matrix in LDS memory, dst of blockwise copy - // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - LDSTypeA, + ADataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, ABlockTransferSrcVectorDim, 2, ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, IndexType, - 1, - BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf = make_static_buffer( + auto b_block_buf_ping = make_static_buffer( b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2{}, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, + Sequence<0, 1, 2, 3, 4>, 4, BBlockTransferSrcScalarPerVector, BThreadTransferSrcResetCoordinateAfterRun, @@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto waveId_m = wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; auto a_thread_offset_m = waveId_m; + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< AScaleDataType, AScaleDataType, @@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths Sequence<0, 1, 2>, // DimAccessOrder 2, // SrcVectorDim - KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector 1, // SrcScalarStrideInVector true>(b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, @@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * expert_stride / BPackedSize, + p_b_grid_up + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); - const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, + auto b_blockwise_copy_up = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< BScaleDataType, BScaleDataType, @@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle thread_offset_shuffled / scale_pack_size_b)); blockwise_gemm_pipeline.template Run( + // A a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, + // Gate and Up b_grid_desc_bpreshuffled, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_blockwise_copy_up, b_grid_buf, b_grid_buf_up, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, + // C c_thread_buf, c_thread_buf_up, + // A scale a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf, + // B scale b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_thread_copy_up, @@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle else { blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, + a_grid_desc_ak0_m_ak1, // A a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_bpreshuffled, + b_grid_desc_bpreshuffled, // B b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_bufs, b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale a_scale_thread_copy, a_scale_grid_buf, - b_scale_grid_desc_bn_ak, + b_scale_grid_desc_bn_ak, // B scale b_scale_thread_copy, b_scale_grid_buf, num_k_block_main_loop); @@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); } - } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); }); }); }); @@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl M3, - M4)), + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto n_thread_data_on_block_idx = @@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 1; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, @@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence Date: Fri, 7 Nov 2025 11:42:39 +0800 Subject: [PATCH 009/118] fix MX bpreshuffle gemm B grid descriptor dimension error. (#3170) --- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 3d2ef9b6c4..7c5bd606b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); constexpr index_t NkSwizzleNumber = Number{}; - return make_naive_tensor_descriptor_packed( - make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + return make_naive_tensor_descriptor_packed(make_tuple( + math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber)); } __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( From d5746dd120c5d5ed9fd4558af0f189ec6308a155 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Tue, 11 Nov 2025 00:12:23 +0530 Subject: [PATCH 010/118] [CK-Tile] Add gtests for compiler CI for faster testing (#3123) * Add gtests for compiler CI for faster testing * Add changes to have a custom target * Add a gtest suite for gemm kernel for running CI tests with compiler mode * Fix Clang error (EOL) * Removed compiler subfolder from CMake * Add gtest suite for gemm kernel * Disable failed tests * Fix build errors * Resolved PR comments * Update shape for persistent gemm kernel test * Seperated types by H/W archs * Made changes to persistent types * Fix persistent build failure issue --------- Co-authored-by: Thomas Ning --- test/ck_tile/gemm/CMakeLists.txt | 6 + .../gemm/test_gemm_pipeline_compiler.cpp | 900 ++++++++++++++++++ 2 files changed, 906 insertions(+) create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 96c071cbc4..c08ab33b91 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -22,6 +22,12 @@ else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") + add_gtest_executable(test_gemm_pipeline_compiler test_gemm_pipeline_compiler.cpp) + target_compile_options(test_gemm_pipeline_compiler PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp new file mode 100644 index 0000000000..bf39e0b552 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compiler.cpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +// ============================================================================ +// Comprehensive GEMM Compiler Validation Test Suite +// This file consolidates all GEMM pipeline tests for compiler validation +// Covers essential combinations of data types, layouts, and pipeline types +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Test Class Definitions for Different Pipeline Types +// ---------------------------------------------------------------------------- + +template +class TestGemmMem : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmMemWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV3 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV3Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV4 : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmCompV4Wmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +template +class TestGemmCompV6 : public TestCkTileGemmPipeline> +{ +}; + +template +class TestGemmPersistent : public TestCkTileGemmPipeline> +{ +}; + +#if defined(CK_TILE_USE_WMMA) +template +class TestGemmPersistentWmma : public TestCkTileGemmPipeline> +{ +}; +#endif + +// ---------------------------------------------------------------------------- +// Type Definitions for Each Pipeline Configuration +// ---------------------------------------------------------------------------- + +// Memory Pipeline Types +using MemTestTypes = ::testing::Types< + // Parameters: ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, + // M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, + // PipelineType + + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Memory Pipeline WMMA Types +using MemWmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV3 Pipeline Types +using CompV3TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV3 Pipeline WMMA Types +using CompV3WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV4 Pipeline Types +using CompV4TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// CompV4 Pipeline WMMA Types +using CompV4WmmaTestTypes = ::testing::Types< + std::tuple, + std::tuple>; +#endif + +// CompV6 Pipeline Types +using CompV6TestTypes = ::testing::Types< + std::tuple, + std::tuple>; + +// Persistent CompV3 Pipeline Types +using PersistentTestTypes = ::testing::Types, + std::tuple>; + +#if defined(CK_TILE_USE_WMMA) +// Persistent CompV3 Pipeline WMMA Types +using PersistentWmmaTestTypes = ::testing::Types, + std::tuple>; +#endif + +// ---------------------------------------------------------------------------- +// Test Suite Registrations +// ---------------------------------------------------------------------------- + +TYPED_TEST_SUITE(TestGemmMem, MemTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmMemWmma, MemWmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV3, CompV3TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV3Wmma, CompV3WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV4, CompV4TestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmCompV4Wmma, CompV4WmmaTestTypes); +#endif +TYPED_TEST_SUITE(TestGemmCompV6, CompV6TestTypes); +TYPED_TEST_SUITE(TestGemmPersistent, PersistentTestTypes); +#if defined(CK_TILE_USE_WMMA) +TYPED_TEST_SUITE(TestGemmPersistentWmma, PersistentWmmaTestTypes); +#endif + +// ============================================================================ +// Memory Pipeline Tests (Mem) +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMem + +TYPED_TEST(TEST_SUITE_NAME, SmallM_SingleRow) +{ + std::vector Ms{1}; + constexpr int N = 1024; + constexpr int K = TestFixture::K_Tile * 2; + + for(int M : Ms) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_M) +{ + this->Run(TestFixture::M_Tile * 2, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_N) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile * 2, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, ExactlyTwoTiles_K) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile * 2); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_512x1024x512) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, Square_1024x1024x1024) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_2048x2048x2048) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, VeryLargeMatrix_4096x4096x4096) +{ + constexpr int M = 4096; + constexpr int N = 4096; + constexpr int K = 4096; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, TallSkinny_4096x128x1024) +{ + constexpr int M = 4096; + constexpr int N = 128; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, ShortWide_128x4096x1024) +{ + constexpr int M = 128; + constexpr int N = 4096; + constexpr int K = 1024; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, DeepNarrow_2048x2048x8192) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 8192; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyTallMatrix) +{ + constexpr int M = 16384; + constexpr int N = 64; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_ExtremelyWideMatrix) +{ + constexpr int M = 64; + constexpr int N = 16384; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StressTest_VeryDeepK) +{ + constexpr int M = 1024; + constexpr int N = 1024; + constexpr int K = 16384; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Memory Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmMemWmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_WMMA) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_WMMA) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_WMMA) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V3 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV3) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, BatchedSmall_CompV3) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V3 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV3Wmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV3Wmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV3Wmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV3Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV3Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V4 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV4) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Compute V4 Pipeline Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV4Wmma + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV4Wmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV4Wmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV4Wmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA + +// ============================================================================ +// Compute V6 Pipeline Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmCompV6 + +TYPED_TEST(TEST_SUITE_NAME, SmallM_CompV6) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_CompV6) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, MidLargeM_CompV6) +{ + std::vector Ms{127, 255}; + constexpr int N = 1024; + + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + constexpr int VecLoadSize = (std::is_same_v || + std::is_same_v || + std::is_same_v) + ? 16 + : 8; + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_CompV6) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_CompV6) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +// ============================================================================ +// Persistent Kernel Tests +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmPersistent + +TYPED_TEST(TEST_SUITE_NAME, SmallM_Persistent) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_Persistent) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_Persistent) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_Persistent) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME + +#if defined(CK_TILE_USE_WMMA) +// ============================================================================ +// Persistent Kernel Tests with WMMA +// ============================================================================ + +#define TEST_SUITE_NAME TestGemmPersistentWmma + +TYPED_TEST(TEST_SUITE_NAME, SmallM_PersistentWmma) +{ + std::vector Ms{1, 2}; + constexpr int N = 1024; + std::vector Ks; + for(auto K_count : {2, 4}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } + + for(int M : Ms) + { + for(int K : Ks) + { + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } + } + } +} + +TYPED_TEST(TEST_SUITE_NAME, SingleTile_PersistentWmma) +{ + this->Run(TestFixture::M_Tile, TestFixture::N_Tile, TestFixture::K_Tile); +} + +TYPED_TEST(TEST_SUITE_NAME, Regular_PersistentWmma) +{ + constexpr int M = 512; + constexpr int N = 1024; + constexpr int K = 512; + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, LargeMatrix_PersistentWmma) +{ + constexpr int M = 2048; + constexpr int N = 2048; + constexpr int K = 2048; + this->Run(M, N, K); +} + +#undef TEST_SUITE_NAME +#endif // CK_TILE_USE_WMMA From e593a14ae1677d7aed696589e8740796bc6085c1 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 11 Nov 2025 02:58:08 +0800 Subject: [PATCH 011/118] [ck] correct memory size in grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 (#3168) b1 and b0 use same layout, so, the size of b1_tensors_device should be same with b0_tensors_device's --- .../grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 63343df3a8..6f30bdaa73 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -221,8 +221,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); - b1_tensors_device.emplace_back( - std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); From 7b6ba8d5c2dc7663e15bd8811c18b4c51cf94c99 Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Tue, 11 Nov 2025 02:58:20 +0800 Subject: [PATCH 012/118] [ck] Enable missing op for gfx11 and gfx12 (#3187) --- profiler/src/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9f86f6d88f..c22867fbed 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp) list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() +endif() + +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) @@ -53,7 +56,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) @@ -74,7 +77,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) - endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR From 9f33b7cfd3df3fcfd540f7633b0abd7019935761 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 10 Nov 2025 11:08:41 -0800 Subject: [PATCH 013/118] fix input range (#3188) --- example/ck_tile/03_gemm/run_gemm_example.inc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index d5f164c40f..703ab810d8 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -309,8 +309,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-2.f, 2.f}(a_m_k); + ck_tile::FillUniformDistribution{-2.f, 2.f}(b_k_n); } else if(init_method == 1) { From 1c544abf57d5a98280c6e26194d568ca475de799 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Tue, 11 Nov 2025 16:38:15 +0100 Subject: [PATCH 014/118] Extend support for ak1 / bk1 WMMA (#3073) * Extend AK1 / BK1 support: - Add support for AK1 != BK1 - Add support for AK1, BK1 > 8 - Introduce KInner template parameter for pipelines when loading multiple tiles with one instruction * fix clang format --- example/01_gemm/gemm_wmma_fp8_v3.cpp | 10 +- .../blockwise_gemm_pipeline_wmma_selector.hpp | 3 + .../blockwise_gemm_pipeline_wmmaops_base.hpp | 47 +-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 302 ++++++++++-------- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 250 +++++++++------ .../gridwise_ab_transfer_thread_tiles.hpp | 98 +++++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 6 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 159 ++++++--- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 25 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 14 + ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 4 +- ...emm_wmma_universal_f16_f8_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp | 3 +- 24 files changed, 621 insertions(+), 339 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp index 0376820b7b..2f8eac113b 100644 --- a/example/01_gemm/gemm_wmma_fp8_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t; using ComputeTypeA = ck::f8_t; using ComputeTypeB = ck::f8_t; -using ALayout = Row; +using ALayout = Col; using BLayout = Col; using CLayout = Row; @@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 64, - 8, 8, + 16, 16, // AK1, BK1 16, 16, 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 4, 16, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, - S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, + 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp index 8cff087ddb..89952910e6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -28,6 +28,7 @@ template constexpr auto BlockGemmPipeline_Selector() { @@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) @@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 265db9166a..abc9720714 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -30,6 +30,7 @@ template struct BlockwiseGemmWmmaops_pipeline_base { @@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; using ThisThreadBlock = ThisThreadBlock; @@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr index_t B_KRow = 1; #endif - static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); - static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner; + static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread); + static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread); static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - static constexpr index_t KRepeat = KPerBlock / KPack; static constexpr auto WmmaK = Number{}; @@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |MRepeat|MWave |KRow |MLane |KPack - return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); } __device__ static auto CalculateBThreadOriginDataIndex() @@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack - return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); } template @@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base return make_tuple(c_thread_m, c_thread_n); } - using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + using Tuple7 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. @@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base * repeat dimensions. */ __host__ __device__ - BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), - Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(), + Tuple7 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && @@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); // C[M, N, NumRegWmma] @@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 5d7c570428..5f731933e2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; - using Base::WaveSize; using typename Base::HotLoopInstList; using Base::A_K1; @@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0), - a_thread_buf); - + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); if constexpr(m0 == I0) { if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); } else @@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), + make_tuple(I0, n0, k0, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs( I0)[Number{}], b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), + make_tuple(I0, n0, I0, I0, I0, I0, I0), b_thread_buf); }); } } - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, I0, I0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + I0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, I0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + static_for<0, KInner, 1>{}([&](auto) { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); }); }); }); @@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, I1, I1, Number{})); + make_tuple(Number{}, I1, I1, I1, I1, I1, Number{})); // B[NRepeat, N1, N2, KPack] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; @@ -399,6 +418,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; @@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - m0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0_inner, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs(I0)[Number< n0 * BScaleStruct::num_scale_k_block + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && + m0 == MRepeat - 1 && n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. - // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && - n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } }); }); }); @@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); using AThreadCopy = @@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 83dadb2175..cbe13b6e00 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; @@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same_v) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } else { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.b_scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } }); @@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 465952e285..23f16d38e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -17,6 +17,9 @@ template {}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(ABK1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). + // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } } __device__ static constexpr auto GetBlockStep() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 68476ef3bf..a36ccd43ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -313,14 +313,16 @@ struct ABTransferWaveTiles // This is a block descriptor used to read LDS memory into register // It's defined in a way consistent with the existing implementation to // avoid changes in the pipelines - return make_naive_tensor_descriptor(make_tuple(Number{}, + return make_naive_tensor_descriptor(make_tuple(I1, Number{}, + Number{}, Number{}, Number{}, Number{}, Number{}), - make_tuple(Number{}, + make_tuple(I0, Number{}, + Number{}, Number{}, Number{}, Number{}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index fa7eb4faaa..38ebdab65e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - // TODO: I am pretty sure this is always 16 and *should* always be 16. - static constexpr auto KPack = - math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16); + static constexpr index_t KPerWmmaBlk = + WmmaSelector::selected_wmma + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return b1_block_copy_step; } + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(K1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). + // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } + } + template __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) { - constexpr auto a_wave_desc = [&]() { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto A_KRow = I2; -#else - constexpr auto A_KRow = I1; -#endif - return transform_tensor_descriptor( - ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return a_wave_desc; + return MakeWmmaTileDescriptor(ABlockDesc_{}); } template __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) { - constexpr auto b0_wave_desc = [&]() { - // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto B_KRow = I2; -#else - constexpr auto B_KRow = I1; -#endif - return transform_tensor_descriptor( - B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return b0_wave_desc; + return MakeWmmaTileDescriptor(B0BlockDesc_{}); } template @@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 MRepeat, LRepeat, KPack, + KInner, true>())>; // TransposeC (must be true to work), C' = B' x A' // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7a5e324468..56f09cee96 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), + static constexpr index_t KPerWmmaBlk = WmmaSelector::selected_wmma - .k_per_wmma); + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, MPerWmma, AK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingA, PermuteA, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, NPerWmma, BK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingB, PermuteB, BBlockTransferThreadClusterLengths_BK0_N_BK1, @@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base NPerWmma, MRepeat, NRepeat, - KPack>())>; + KPack, + KInner>())>; // Used to create obj in global function and pass it to Run method using EpilogueCShuffle = diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index bca68764f9..55ede990af 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -95,6 +95,7 @@ struct wmma_type __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 71b5c5e7cf..806b6e684d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -48,7 +48,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index f4489dc45f..4516d06492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 423f86365c..5ace0594f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -53,7 +53,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index 2eb28958e6..27deab1c8c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -56,7 +56,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index d10b9facd5..bd5c7d8783 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -48,7 +48,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index d9d16ede65..1956d1a951 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 9277e5e901..934c6aa7ef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp index e97a649c19..9860b81b78 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp index c8f1b85ddb..4d7169565a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp index fc0220a502..3728368bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp index b87cf64b0f..3506575f5d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp index 31ad66409e..eef0d6de6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp @@ -50,7 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp index 4c37c398fe..2418be62b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -55,7 +55,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp index 6b5314b701..38f2869303 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance From 06c651b100c9dc50753277069bdc68411da7ca1a Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 11 Nov 2025 07:42:26 -0800 Subject: [PATCH 015/118] formatting (#3182) --- include/ck_tile/ops/gemm_quant.hpp | 1 + .../block/block_gemm_quant_common.hpp | 38 +++++++++++++++++++ ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 18 ++------- .../block_universal_gemm_as_aquant_bs_cr.hpp | 17 ++------- .../block_universal_gemm_as_bs_bquant_cr.hpp | 17 ++------- 5 files changed, 48 insertions(+), 43 deletions(-) create mode 100644 include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3273131875..3e16d937cb 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp new file mode 100644 index 0000000000..d695888b88 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Common utilities for quantized GEMM block operations +template +struct BlockGemmQuantCommon +{ + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index df55081b69..2d92745f75 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -100,21 +101,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b95ec6ddf..1f72f4dc12 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -543,20 +544,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9db444b57f..660c30aa6e 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -376,20 +377,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template From aa1fb29aa102d937e061f138ecb22ef81e7a8fcd Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 11 Nov 2025 07:44:38 -0800 Subject: [PATCH 016/118] Bump commit ref for TheRock in workflows (#3189) * Bump commit ref for TheRock in workflows * Update to more recent commit (could also `rm` the patch) * Revert "Update to more recent commit (could also `rm` the patch)" This reverts commit 4b9f4952ead77e068f5ab86a07701c7e9bed48cc. * Rm patch that no longer applies * Fix post_build_upload flag name * Fix artifact_group plumbing for setup test env --- .github/workflows/therock-ci-linux.yml | 6 ++++-- .github/workflows/therock-test-component.yml | 4 ++-- .github/workflows/therock-test-packages.yml | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index f4d0c0063c..86d134e456 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -53,8 +53,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit path: "TheRock" + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Setup ccache run: | @@ -77,6 +77,8 @@ jobs: - name: Patch rocm-libraries run: | git config --global --add safe.directory '*' + # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo + rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps @@ -128,7 +130,7 @@ jobs: run: | python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --artifact-group ${{ env.AMDGPU_FAMILIES }} \ --build-dir TheRock/build \ --upload diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 1ccc1d57bc..27eff4fdb0 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,13 +51,13 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' with: ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + ARTIFACT_GROUP: ${{ inputs.amdgpu_families }} OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} VENV_DIR: ${{ env.VENV_DIR }} FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index efb5a6b1a0..81632fce48 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: "Configuring CI options" env: From 88e3212fccf2a879c0e718deecc28caff453bb29 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 11:17:24 -0500 Subject: [PATCH 017/118] chore(copyright): update copyright header for tile_engine directory (#3180) --- tile_engine/ops/commons/test_benchmark.sh | 3 +++ tile_engine/ops/commons/test_validation.py | 3 +++ tile_engine/ops/commons/validation_utils.py | 2 +- tile_engine/ops/gemm/codegen_utils.py | 2 +- tile_engine/ops/gemm/gemm_benchmark.hpp | 2 +- tile_engine/ops/gemm/gemm_benchmark.py | 2 +- tile_engine/ops/gemm/gemm_benchmark_single.cpp | 2 +- tile_engine/ops/gemm/gemm_common.hpp | 2 +- tile_engine/ops/gemm/gemm_instance_builder.py | 3 +++ tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py | 2 +- .../ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp | 2 +- tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py | 3 +++ tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp | 2 +- tile_engine/ops/gemm_preshuffle/commons/validation_utils.py | 2 +- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp | 3 +++ tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py | 2 +- .../ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp | 2 +- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp | 2 +- .../ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py | 4 ++-- tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp | 2 +- 23 files changed, 34 insertions(+), 19 deletions(-) diff --git a/tile_engine/ops/commons/test_benchmark.sh b/tile_engine/ops/commons/test_benchmark.sh index 1fb7c163af..e2e0324da8 100755 --- a/tile_engine/ops/commons/test_benchmark.sh +++ b/tile_engine/ops/commons/test_benchmark.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Test script for tile engine GEMM benchmarks # This script demonstrates how to run the new individual benchmark executables diff --git a/tile_engine/ops/commons/test_validation.py b/tile_engine/ops/commons/test_validation.py index 79f24265f1..46fb008c27 100644 --- a/tile_engine/ops/commons/test_validation.py +++ b/tile_engine/ops/commons/test_validation.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + """ Test script to verify that the validation logic is working correctly. """ diff --git a/tile_engine/ops/commons/validation_utils.py b/tile_engine/ops/commons/validation_utils.py index 3eb7bf8b57..5787446e8c 100644 --- a/tile_engine/ops/commons/validation_utils.py +++ b/tile_engine/ops/commons/validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 0020fccf05..eecc2228a6 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # -*- coding: utf-8 -*- diff --git a/tile_engine/ops/gemm/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_benchmark.hpp index 0e2619785e..7c8df32ad8 100644 --- a/tile_engine/ops/gemm/gemm_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_benchmark.py index 9f323f2640..cc04dbe0db 100755 --- a/tile_engine/ops/gemm/gemm_benchmark.py +++ b/tile_engine/ops/gemm/gemm_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm/gemm_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_benchmark_single.cpp index bbcc6eb505..6323c066a1 100644 --- a/tile_engine/ops/gemm/gemm_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 1aff42b902..8885c821c1 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 575e5240a8..3c6bbc34d3 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp index 53dcdb5e1f..f8c196e32a 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py index fb81b9c2c2..044e08baca 100755 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 032a625354..41d2f736e1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 3f7858f146..cc167fb75f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp index 8e19c11c7d..3a2cdc71fe 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py index b38ff5dffb..70ce3b0d72 100644 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 77a9f26527..748fe581d3 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include "ck_tile/core.hpp" diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py index 0217a439f2..d8892be7d6 100755 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 1f03d1cf9b..4fbb25f0c9 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index abaa5ebd46..1b2cfe3735 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 57c250f57e..9ce6d8cb25 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -1,5 +1,5 @@ -## Copyright © Advanced Micro Devices, Inc. or its affiliates. -## SPDX-License-Identifier: MIT +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT import argparse import os diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 85b731c231..739bd7e677 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc. or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once From 1b1c46e508c1fd40a03f54114b6b78629032fb4f Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 12 Nov 2025 00:23:57 +0800 Subject: [PATCH 018/118] [CK_TILE] Fix gemm_quant (#3186) --- .../38_block_scale_gemm/CMakeLists.txt | 2 +- .../38_block_scale_gemm/gemm_quant_basic.cpp | 4 + .../38_block_scale_gemm/gemm_utils.hpp | 8 ++ include/ck_tile/host/tensor_shuffle_utils.hpp | 98 ++++++++++++++----- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 1 + ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 4 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 11 +-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 6 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 5 +- .../pipeline/tile_gemm_quant_traits.hpp | 5 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 2 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 14 ++- .../test_gemm_quant_fixtures.hpp | 24 +++-- 13 files changed, 135 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..b1ae9369a2 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index b22596537f..d605a2b780 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f) int main(int argc, char* argv[]) { +#if CK_TILE_USE_WMMA + return !run_gemm_example(argc, argv); +#else // Use non-preshuffled GemmConfig for 2D block scale support return !run_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 589caf88f4..1839c7f98d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -55,21 +82,46 @@ template auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, - NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + } } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 90f6204ff3..dd2931f6b7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma static constexpr index_t kM = Impl::kM; static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; + static constexpr index_t kCMLane = Impl::kCMLane; static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 2d92745f75..6422c07e1d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -82,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 1f72f4dc12..bbdd3128bf 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmAQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -349,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3. - constexpr uint32_t kTileRowsOfCPerThread = 4; + constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8; decltype(threadIdx.x) pull_from_lane = 0; if constexpr(WarpGemm::kM == 16) { @@ -410,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // desired row coefficient auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; + ; constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; // Multiply by 4 because output is stored in tiles of 4 diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 660c30aa6e..28ae709bf0 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -25,13 +25,11 @@ struct BlockGemmBQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 36cbb87877..15d2727f3b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -240,7 +240,10 @@ struct QuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index c4429b76f9..3a5b86382d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -41,7 +41,8 @@ template + bool UsePersistentKernel_ = false, + int VectorSize_ = 16> struct TileGemmQuantTraits { static constexpr bool kPadM = kPadM_; @@ -50,7 +51,7 @@ struct TileGemmQuantTraits static constexpr QuantType kQuantType = QuantType_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using ALayout = ALayout_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 3a49e69c37..1c4a25c8bd 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 6454101daf..6226a2de9e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - + // WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize. + // static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + // VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only + // need to check the limitation on RDNA cards, i.e. assume wave size is 32. + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test ALayout, BLayout, GemmConfig::TransposeC, - DoubleSmemBuffer>; + DoubleSmemBuffer, + false, + VectorSize>; // Let the derived class create the appropriate pipeline and epilogue static_cast(this) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index cabc0ec02c..5aac095514 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -7,6 +7,16 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else + return is_8bit ? 64 : 32; +#endif +} + struct GemmConfigBase { static constexpr bool kPadM = false; @@ -40,7 +50,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleQuant : public GemmConfigBase @@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefill : public GemmConfigBase @@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill @@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase Date: Tue, 11 Nov 2025 14:26:01 -0500 Subject: [PATCH 019/118] chore(copyright): update copyright header for script directory (#3184) * chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory --------- Co-authored-by: Vidyasagar Ananthan --- script/check_copyright_year.sh | 3 +++ script/clang-format-overwrite.sh | 3 +++ script/cmake-ck-dev.sh | 3 +++ script/convert_miopen_driver_to_profiler.py | 3 ++- script/count_vgpr.sh | 3 +++ .../generate_list_of_files_not_referenced_in_tests.py | 4 ++-- script/dependency-parser/main.py | 2 +- script/dependency-parser/src/enhanced_ninja_parser.py | 2 +- script/dependency-parser/src/selective_test_filter.py | 2 +- script/gemm_profile.sh | 2 +- script/hipclang_opt.sh | 3 +++ script/install_precommit.sh | 3 +++ script/launch_tests.sh | 2 +- script/ninja_json_converter.py | 2 +- script/process_perf_data.py | 3 +++ script/process_perf_data.sh | 3 +++ script/process_qa_data.sh | 3 +++ script/profile_batched_gemm.sh | 3 +++ script/profile_gemm.sh | 3 +++ script/profile_gemm_bilinear.sh | 3 +++ script/profile_grouped_conv_bwd_data.sh | 3 +++ script/profile_grouped_conv_bwd_weight.sh | 3 +++ script/profile_grouped_conv_fwd.sh | 3 +++ script/profile_grouped_conv_fwd_outelementop.sh | 3 +++ script/profile_grouped_gemm.sh | 3 +++ script/profile_mixed_gemm.sh | 3 +++ script/profile_onnx_gemm.sh | 3 +++ script/profile_permute_scale.sh | 3 +++ script/profile_reduce_no_index.sh | 3 +++ script/profile_reduce_with_index.sh | 3 +++ script/profile_resnet50.sh | 3 +++ script/profile_splitK_gemm.sh | 3 +++ script/remod_for_ck_tile.py | 3 +++ script/remove_exec_bit.sh | 2 +- script/run_ck_profiler_gemm_with_csv_shapes.py | 2 +- script/run_full_performance_tests.sh | 3 +++ script/run_gemm_performance_tests.sh | 3 +++ script/run_performance_tests.sh | 3 +++ script/sccache_wrapper.sh | 3 +++ script/test_convnd_fwd.sh | 3 +++ script/test_reduce_no_index.sh | 3 +++ script/uninstall_precommit.sh | 3 +++ 42 files changed, 108 insertions(+), 11 deletions(-) diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh index f7709472ef..1b63c6b711 100755 --- a/script/check_copyright_year.sh +++ b/script/check_copyright_year.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + current_year=$(date +%Y) exit_code=0 diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 74391ded28..23b57b9935 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | grep -v 'build/' | grep -v 'include/rapidjson'| xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|.hpp|.inc|include/rapidjson/")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 6220009b03..9643af1de0 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # exit when a command exits with non-zero status; also when an unbound variable is referenced set -eu # pipefail is supported by many shells, not supported by sh and dash diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index d814e0719c..5aff9c0a7f 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -1,5 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + # Convert miopen driver command to ck Profiler # Example: python3 ../script/convert_miopen_driver_to_profiler.py # /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3 diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh index 07debc53a8..651a894db6 100755 --- a/script/count_vgpr.sh +++ b/script/count_vgpr.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + FILE=$1 for num in {0..255} diff --git a/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py index 8419b9491e..58bb9e8e93 100644 --- a/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py +++ b/script/dependency-parser/generate_list_of_files_not_referenced_in_tests.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -## Copyright © Advanced Micro Devices, Inc. or its affiliates. -## SPDX-License-Identifier: MIT # This script generate list of files that are not referenced from any test (list in JSON format) # Script only looks at not referenced files from three directories: include, library and profiler diff --git a/script/dependency-parser/main.py b/script/dependency-parser/main.py index 623ae05afd..f345362b26 100644 --- a/script/dependency-parser/main.py +++ b/script/dependency-parser/main.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/dependency-parser/src/enhanced_ninja_parser.py b/script/dependency-parser/src/enhanced_ninja_parser.py index ff6344a4c1..2ac8e8537a 100644 --- a/script/dependency-parser/src/enhanced_ninja_parser.py +++ b/script/dependency-parser/src/enhanced_ninja_parser.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/dependency-parser/src/selective_test_filter.py b/script/dependency-parser/src/selective_test_filter.py index d3228ef624..83f7f7eebe 100644 --- a/script/dependency-parser/src/selective_test_filter.py +++ b/script/dependency-parser/src/selective_test_filter.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/gemm_profile.sh b/script/gemm_profile.sh index 89419ca711..d3d66bcaa9 100755 --- a/script/gemm_profile.sh +++ b/script/gemm_profile.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT BIN=./bin/tile_example_gemm_weight_preshuffle diff --git a/script/hipclang_opt.sh b/script/hipclang_opt.sh index c51bd51d97..ba5636eeb6 100755 --- a/script/hipclang_opt.sh +++ b/script/hipclang_opt.sh @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + rm *.ll *.s BC_FILE=$1 diff --git a/script/install_precommit.sh b/script/install_precommit.sh index 545dcfa666..f80b06a95a 100755 --- a/script/install_precommit.sh +++ b/script/install_precommit.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + run_and_check() { "$@" status=$? diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 52151b71f6..1911613023 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT # Get the directory where the script is located diff --git a/script/ninja_json_converter.py b/script/ninja_json_converter.py index e68f7ccfa3..5e974cf730 100644 --- a/script/ninja_json_converter.py +++ b/script/ninja_json_converter.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/script/process_perf_data.py b/script/process_perf_data.py index b35ba64041..5f81512a4c 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import io import argparse diff --git a/script/process_perf_data.sh b/script/process_perf_data.sh index 50c84924f5..4786ddded0 100755 --- a/script/process_perf_data.sh +++ b/script/process_perf_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd need the following python packages: diff --git a/script/process_qa_data.sh b/script/process_qa_data.sh index 420453cddc..d56ef5c1ec 100755 --- a/script/process_qa_data.sh +++ b/script/process_qa_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd need the following python packages: diff --git a/script/profile_batched_gemm.sh b/script/profile_batched_gemm.sh index f90baaed68..bb7d61deec 100755 --- a/script/profile_batched_gemm.sh +++ b/script/profile_batched_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index b88159e74d..f766ca50fa 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_gemm_bilinear.sh b/script/profile_gemm_bilinear.sh index e6edefae85..057d7d7e49 100755 --- a/script/profile_gemm_bilinear.sh +++ b/script/profile_gemm_bilinear.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 DRIVER="../build/bin/ckProfiler" diff --git a/script/profile_grouped_conv_bwd_data.sh b/script/profile_grouped_conv_bwd_data.sh index a1d2f450c9..3805ed86cd 100755 --- a/script/profile_grouped_conv_bwd_data.sh +++ b/script/profile_grouped_conv_bwd_data.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_bwd_weight.sh b/script/profile_grouped_conv_bwd_weight.sh index e3652202d4..146431621c 100755 --- a/script/profile_grouped_conv_bwd_weight.sh +++ b/script/profile_grouped_conv_bwd_weight.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_fwd.sh b/script/profile_grouped_conv_fwd.sh index 9a974525ad..8491aecf9e 100755 --- a/script/profile_grouped_conv_fwd.sh +++ b/script/profile_grouped_conv_fwd.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_conv_fwd_outelementop.sh b/script/profile_grouped_conv_fwd_outelementop.sh index ac444a25c2..a0df8cd4c5 100755 --- a/script/profile_grouped_conv_fwd_outelementop.sh +++ b/script/profile_grouped_conv_fwd_outelementop.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_grouped_gemm.sh b/script/profile_grouped_gemm.sh index 8adb7c81ac..fe452d5cab 100755 --- a/script/profile_grouped_gemm.sh +++ b/script/profile_grouped_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_mixed_gemm.sh b/script/profile_mixed_gemm.sh index 383c7ea36e..a867bf3a77 100755 --- a/script/profile_mixed_gemm.sh +++ b/script/profile_mixed_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_onnx_gemm.sh b/script/profile_onnx_gemm.sh index c2721e7f59..ea18fc761e 100755 --- a/script/profile_onnx_gemm.sh +++ b/script/profile_onnx_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 DRIVER="../build/bin/ckProfiler" diff --git a/script/profile_permute_scale.sh b/script/profile_permute_scale.sh index 945d10f47b..31d6a06c5e 100755 --- a/script/profile_permute_scale.sh +++ b/script/profile_permute_scale.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_reduce_no_index.sh b/script/profile_reduce_no_index.sh index 66bfe1dcd3..3bae07906b 100755 --- a/script/profile_reduce_no_index.sh +++ b/script/profile_reduce_no_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + DRIVER="../build/bin/ckProfiler" VERIFY="-v $1" INIT=$2 diff --git a/script/profile_reduce_with_index.sh b/script/profile_reduce_with_index.sh index 43543f4430..943a590528 100755 --- a/script/profile_reduce_with_index.sh +++ b/script/profile_reduce_with_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + DRIVER="../build/bin/ckProfiler" VERIFY="-v $1" INIT=$2 diff --git a/script/profile_resnet50.sh b/script/profile_resnet50.sh index b55cb2ccef..ec6b32c0c8 100755 --- a/script/profile_resnet50.sh +++ b/script/profile_resnet50.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/profile_splitK_gemm.sh b/script/profile_splitK_gemm.sh index d62f0e4753..843d59c918 100755 --- a/script/profile_splitK_gemm.sh +++ b/script/profile_splitK_gemm.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## GPU visibility export HIP_VISIBLE_DEVICES=0 diff --git a/script/remod_for_ck_tile.py b/script/remod_for_ck_tile.py index 7601c9d619..feb50dc290 100755 --- a/script/remod_for_ck_tile.py +++ b/script/remod_for_ck_tile.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os root_dir = os.getcwd() diff --git a/script/remove_exec_bit.sh b/script/remove_exec_bit.sh index 2926683d6a..0b3ca80422 100755 --- a/script/remove_exec_bit.sh +++ b/script/remove_exec_bit.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT for file in $(git diff --cached --name-only --diff-filter=ACM | grep -E '\.(cpp|hpp|txt|inc)$'); do diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index eb0eb9c920..2590e3942e 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # -*- coding: utf-8 -*- diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index 508200b21a..55740da097 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # you would also need to set up some environment variables in order to diff --git a/script/run_gemm_performance_tests.sh b/script/run_gemm_performance_tests.sh index 12adad30f8..c72b2a760b 100755 --- a/script/run_gemm_performance_tests.sh +++ b/script/run_gemm_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # run the script as "./run_gemm_performance_tests.sh diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh index 4e13b59d34..9163e6d693 100755 --- a/script/run_performance_tests.sh +++ b/script/run_performance_tests.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # # in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ # run the script as "./run_performance_tests.sh diff --git a/script/sccache_wrapper.sh b/script/sccache_wrapper.sh index 30fd17e520..1a7e37881e 100755 --- a/script/sccache_wrapper.sh +++ b/script/sccache_wrapper.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set -e COMPILERS_HASH_DIR=${COMPILERS_HASH_DIR:-"/tmp/.sccache"} SCCACHE_EXTRAFILES=${SCCACHE_EXTRAFILES:-"${COMPILERS_HASH_DIR}/rocm_compilers_hash_file"} diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh index 8bd2c2fc33..d716caac15 100644 --- a/script/test_convnd_fwd.sh +++ b/script/test_convnd_fwd.sh @@ -1,4 +1,7 @@ #!/usr/bin/env bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # set -e diff --git a/script/test_reduce_no_index.sh b/script/test_reduce_no_index.sh index b956303837..717a872c45 100755 --- a/script/test_reduce_no_index.sh +++ b/script/test_reduce_no_index.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + ## The following will be used for CI diff --git a/script/uninstall_precommit.sh b/script/uninstall_precommit.sh index b0d4d15166..394425acdd 100755 --- a/script/uninstall_precommit.sh +++ b/script/uninstall_precommit.sh @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + pre-commit uninstall From c54ecd905b07849076069d56c284472230564568 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 14:27:33 -0500 Subject: [PATCH 020/118] docs: update ckProfiler readme with selective building option (#3140) * docs: update ckProfiler readme with selective building option * docs: add list of operations for ckProfiler --- profiler/README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/profiler/README.md b/profiler/README.md index 05bbc7b4f9..86f668eacb 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -1,5 +1,23 @@ [Back to the main page](../README.md) # Composable Kernel profiler + +## Building Specific Profilers +To reduce build time, filter which operations to compile using CMake options: + +```bash +# Build all grouped_gemm variants (grouped_gemm, grouped_gemm_fastgelu, grouped_gemm_tile_loop, etc.) +cmake -DCK_PROFILER_OP_FILTER="grouped_gemm" .. + +# Build ONLY base grouped_gemm (excludes variants - use exact regex match with ^ and $) +cmake -DCK_PROFILER_OP_FILTER="^grouped_gemm$" .. +``` + +Both `CK_PROFILER_OP_FILTER` and `CK_PROFILER_INSTANCE_FILTER` accept regex patterns. Default builds all operations. + +To find the complete list of operations, run the following command: +```bash +find profiler/src -name "profile_*.cpp" | sed 's|profiler/src/profile_||' | sed 's|.cpp||' | sort +``` ## Profiler GEMM UNIVERSAL kernels ```bash # arg1: tensor operation (gemm_universal: Universal GEMM) From b145a5fe80d2f9d965f2c8555808017c3a660fc2 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 11 Nov 2025 15:15:49 -0500 Subject: [PATCH 021/118] Add CK Tile Tutorials Folder with GEMM and COPY Kernel (#3038) * feat: add tutorial folder with gemm tutorial * chore: move copy kernel from examples folder to tutorial * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tutorial/ck_tile/01_naive_gemm/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: remove handdrawn images * docs: add write ups to explain the gemm kernel * docs: add about block level pipeline and static distributed tensors --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- CMakeLists.txt | 6 + example/ck_tile/CMakeLists.txt | 1 - tutorial/CMakeLists.txt | 15 + .../ck_tile/00_copy_kernel}/CMakeLists.txt | 6 +- .../ck_tile/00_copy_kernel}/README.md | 0 .../ck_tile/00_copy_kernel}/copy_basic.cpp | 22 +- .../ck_tile/00_copy_kernel}/copy_basic.hpp | 0 .../00_copy_kernel}/test_tile_example.sh | 2 +- .../01_naive_gemm/BLOCK_LEVEL_PIPELINE.md | 589 +++++++++++++++++ tutorial/ck_tile/01_naive_gemm/CMakeLists.txt | 7 + .../01_naive_gemm/HOST_LEVEL_PIPELINE.md | 618 ++++++++++++++++++ .../01_naive_gemm/KERNEL_ENTRY_POINT.md | 464 +++++++++++++ tutorial/ck_tile/01_naive_gemm/README.md | 150 +++++ tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md | 506 ++++++++++++++ ...e_gemm_block_pipeline_agmem_bgmem_creg.hpp | 165 +++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 135 ++++ ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 92 +++ ...tice_gemm_host_policy_agmem_bgmem_creg.hpp | 51 ++ .../ck_tile/01_naive_gemm/practice_gemm.cpp | 131 ++++ .../ck_tile/01_naive_gemm/practice_gemm.hpp | 69 ++ .../ck_tile/01_naive_gemm/reference_gemm.hpp | 36 + ...ce_gemm_warp_pipeline_asmem_bsmem_creg.hpp | 195 ++++++ ...tice_gemm_warp_policy_asmem_bsmem_creg.hpp | 35 + tutorial/ck_tile/CMakeLists.txt | 7 + 24 files changed, 3287 insertions(+), 15 deletions(-) create mode 100644 tutorial/CMakeLists.txt rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/CMakeLists.txt (54%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/README.md (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.cpp (86%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/copy_basic.hpp (100%) rename {example/ck_tile/39_copy => tutorial/ck_tile/00_copy_kernel}/test_tile_example.sh (95%) create mode 100644 tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/CMakeLists.txt create mode 100644 tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md create mode 100644 tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md create mode 100644 tutorial/ck_tile/01_naive_gemm/README.md create mode 100644 tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp create mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 049da5637f..7b4990dba4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -683,6 +683,12 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) + + add_subdirectory(tutorial) + rocm_package_setup_component(tutorials + LIBRARY_NAME composablekernel + PACKAGE_NAME tutorials + ) add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index a6cfcde86e..92ee0a4c31 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -25,7 +25,6 @@ add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) -add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..a2f35ca53f --- /dev/null +++ b/tutorial/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include +) + +message(STATUS "Building tutorials...") +add_custom_target(tutorials) + +# add all tutorial subdir +file(GLOB dir_list LIST_DIRECTORIES true *) +FOREACH(subdir ${dir_list}) + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") + add_subdirectory(${subdir}) + ENDIF() +ENDFOREACH() diff --git a/example/ck_tile/39_copy/CMakeLists.txt b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt similarity index 54% rename from example/ck_tile/39_copy/CMakeLists.txt rename to tutorial/ck_tile/00_copy_kernel/CMakeLists.txt index 98397a33d2..91dd036eff 100644 --- a/example/ck_tile/39_copy/CMakeLists.txt +++ b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt @@ -1,7 +1,9 @@ -add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp) +add_executable(tile_tutorial_copy_kernel EXCLUDE_FROM_ALL copy_basic.cpp) # Impact: This flag ensures that the compiler doesn't make # assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns. -target_compile_options(tile_example_copy PRIVATE +target_compile_options(tile_tutorial_copy_kernel PRIVATE -mllvm -enable-noalias-to-md-conversion=0 ) + +add_dependencies(tutorials tile_tutorial_copy_kernel) diff --git a/example/ck_tile/39_copy/README.md b/tutorial/ck_tile/00_copy_kernel/README.md similarity index 100% rename from example/ck_tile/39_copy/README.md rename to tutorial/ck_tile/00_copy_kernel/README.md diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp similarity index 86% rename from example/ck_tile/39_copy/copy_basic.cpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.cpp index de91dc1be9..282e9ff8c1 100644 --- a/example/ck_tile/39_copy/copy_basic.cpp +++ b/tutorial/ck_tile/00_copy_kernel/copy_basic.cpp @@ -54,10 +54,10 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); // Define tile configuration - using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N - using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension - using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension - using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension + using ThreadTile = ck_tile::sequence<1, 4>; // per-thread tile size along M and N + using WaveTile = ck_tile::sequence<64, 4>; // per-wave tile size along M and N dimension + using BlockWaves = ck_tile::sequence<4, 1>; // number of waves per block along M and N dimension + using BlockTile = ck_tile::sequence<512, 4>; // per-block tile size along M and N dimension // Calculate grid size ck_tile::index_t kGridSize = @@ -68,14 +68,14 @@ bool run(const ck_tile::ArgParser& arg_parser) using Shape = ck_tile::TileCopyShape; using Problem = ck_tile::TileCopyProblem; using Policy = ck_tile::TileCopyPolicy; - using Kernel = ck_tile::ElementWiseTileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel; - // using Kernel = ck_tile::TileCopyKernel_LDS; + using Kernel = ck_tile::ElementWiseTileCopyKernel; // operates on element by + // element basis. - // question: Why do we not have a pipeline? - // answer: For basic copy operation, pipeline is not needed. - // we intentionally do not use pipeline for this example and let the kernel be composite of - // Problem and Policy + // We also implement two variations of the copy kernel: + // 1. TileCopyKernel: This is the basic copy kernel that operates on tile by tile basis. + // 2. TileCopyKernel_LDS: This is the copy kernel that operates on tile by tile basis and uses + // the LDS. using Kernel = ck_tile::TileCopyKernel; using Kernel = + // ck_tile::TileCopyKernel_LDS; auto blockSize = Kernel::BlockSize(); diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/tutorial/ck_tile/00_copy_kernel/copy_basic.hpp similarity index 100% rename from example/ck_tile/39_copy/copy_basic.hpp rename to tutorial/ck_tile/00_copy_kernel/copy_basic.hpp diff --git a/example/ck_tile/39_copy/test_tile_example.sh b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh similarity index 95% rename from example/ck_tile/39_copy/test_tile_example.sh rename to tutorial/ck_tile/00_copy_kernel/test_tile_example.sh index 416338fac4..4ee5fdf15d 100755 --- a/example/ck_tile/39_copy/test_tile_example.sh +++ b/tutorial/ck_tile/00_copy_kernel/test_tile_example.sh @@ -4,7 +4,7 @@ set -euo pipefail -BIN="${BIN:-../../../build/bin/tile_example_copy}" +BIN="${BIN:-../../../build/bin/tile_tutorial_copy_kernel}" WARMUP="${WARMUP:-20}" REPEAT="${REPEAT:-100}" VALIDATE="${VALIDATE:-1}" diff --git a/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..114fccfd56 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md @@ -0,0 +1,589 @@ +# Block-Level Pipeline: PracticeGemmBlockPipelineAGmemBGmemCreg + +## Overview + +The **Block-Level Pipeline** is where the actual GEMM computation happens for one block tile. It orchestrates: +1. **Data movement** from DRAM → Registers → LDS +2. **GEMM computation** using data in LDS +3. **Iteration** over the K dimension when needed + +This pipeline is called by the host-level pipeline for each block tile that covers a portion of the output matrix C. + +--- + +## Architecture: Problem and Policy + +Like other components in CK Tile, the block pipeline follows the **Problem/Policy** pattern: + +### Problem: `PracticeGemmBlockPipelineProblem` +Contains: +- **Data types**: `ADataType`, `BDataType`, `CDataType`, `AccDataType` +- **Shape information**: `BlockTile` and `WaveTile` dimensions + +### Policy: `PracticeGemmBlockPolicy` +Contains strategies for: +1. **Tile Distribution** (`MakeADramTileDistribution`, `MakeBDramTileDistribution`) + - Defines how 256 threads in a block map to elements of a block tile + - Each thread knows which elements to load/store from DRAM to its registers + - We'll cover tile distribution construction in detail later + +2. **LDS Layout** (`MakeALdsBlockDescriptor`, `MakeBLdsBlockDescriptor`) + - Describes how data is logically organized in Local Data Share (LDS) + - Optimizes for bank conflict avoidance and efficient access patterns + - We'll cover LDS descriptor construction in detail later + +3. **Warp Pipeline** (`GetPracticeWaveGemmPipeline`) + - Returns the warp-level GEMM implementation + +--- + +## Inputs and Outputs + +```cpp +template +CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const +``` + +### Inputs: +- `a_dram_block_window_tmp`: Tile window over A in DRAM (size: MPerBlock × KPerBlock) +- `b_dram_block_window_tmp`: Tile window over B in DRAM (size: NPerBlock × KPerBlock) +- `num_loop`: Number of iterations along K dimension +- `p_smem`: Pointer to shared memory (LDS) + +### Output: +- `c_block_tile`: A `static_distributed_tensor` containing the computed C tile in registers (VGPRs) + +--- + +## Step-by-Step Walkthrough + +### Step 1: Create LDS Tensor Views + +```cpp +// A tile in LDS +ADataType* p_a_lds = static_cast(p_smem); +constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); +auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + +// B tile in LDS (placed after A in shared memory) +BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); +constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); +auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); +``` + +**What's happening:** +- We partition the shared memory (`p_smem`) into two regions: one for A, one for B +- We create **tensor views** over these LDS regions using descriptors from the policy +- `a_lds_block` and `b_lds_block` are logical views over raw LDS memory + +**Memory Layout:** +``` +Shared Memory (LDS): +┌─────────────────────┬─────────────────────┐ +│ A Block Tile │ B Block Tile │ +│ (256×32 fp16) │ (128×32 fp16) │ +└─────────────────────┴─────────────────────┘ +↑ ↑ +p_a_lds p_b_lds +``` + +--- + +### Step 2: Create Tile Windows for Data Movement + +We create **6 tile windows** for different purposes: + +#### 2a. DRAM → Registers (Load from DRAM) + +```cpp +auto a_copy_dram_window = make_tile_window( + a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), // 256×32 + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // ← Tile distribution! +``` + +**Key Points:** +- `a_copy_dram_window` is a `tile_window_with_static_distribution` +- The **tile distribution** tells each thread which elements to load from DRAM +- This window will **slide along the K dimension** in the loop + +#### 2b. Registers → LDS (Store to LDS) + +```cpp +auto a_copy_lds_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), // 256×32 + {0, 0}, // Origin at (0, 0) in LDS + a_copy_dram_window.get_tile_distribution()); // ← Same distribution as DRAM! +``` + +**Key Points:** +- Uses the **same tile distribution** as `a_copy_dram_window` +- This ensures each thread stores to LDS in the same pattern it loaded from DRAM +- Origin is always `{0, 0}` because LDS is reused for each K iteration + +#### 2c. LDS → Registers (GEMM Input) + +```cpp +auto a_lds_gemm_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), + {0, 0}); // No tile distribution! +``` + +**Key Points:** +- This is a `tile_window_with_static_lengths` (no explicit distribution) +- Used as input to the warp-level GEMM +- The warp GEMM will handle its own thread mapping internally + +**Similar windows are created for B:** +- `b_copy_dram_window`: Load B from DRAM +- `b_copy_lds_window`: Store B to LDS +- `b_lds_gemm_window`: Read B from LDS for GEMM + +--- + +### Step 3: Create Distributed Tensors (VGPRs) + +```cpp +using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); +using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + +using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); +using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + +ABlockTile a_block_tile; // Per-thread registers for A +BBlockTile b_block_tile; // Per-thread registers for B +``` + +#### What is `make_static_distributed_tensor`? + +**`make_static_distributed_tensor`** creates a **`static_distributed_tensor`**, which is a compile-time abstraction for **distributed per-thread register storage**. + +**Key Properties:** +1. **Per-thread VGPRs**: Each thread owns a **different slice** of the tile in its registers +2. **Compile-time sized**: Buffer size determined by tile distribution at compile time +3. **Zero-overhead**: All indexing and layout transformations happen at compile time + +**How it works:** + +```cpp +template +struct static_distributed_tensor +{ + using DataType = remove_cvref_t; + using StaticTileDistribution = remove_cvref_t; + + // Calculate per-thread storage size from tile distribution + using ThreadTensorDesc = + remove_cvref_t; + + static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + + // Per-thread register array (VGPRs) + thread_buffer thread_buf_; +}; +``` + +**The tile distribution defines:** +- **Which elements each thread owns** in the tile +- **How many elements** each thread stores (buffer size) +- **How elements are laid out** in each thread's registers + +**Concrete Example for 256×32 tile with 256 threads:** + +``` +Thread 0: a_block_tile.thread_buf_ = [A[0,0], A[0,1], ..., A[0,31]] (32 fp16 values) +Thread 1: a_block_tile.thread_buf_ = [A[1,0], A[1,1], ..., A[1,31]] (32 fp16 values) +Thread 2: a_block_tile.thread_buf_ = [A[2,0], A[2,1], ..., A[2,31]] (32 fp16 values) +... +Thread 255: a_block_tile.thread_buf_ = [A[255,0], A[255,1], ..., A[255,31]] (32 fp16 values) +``` + +**Collectively:** +- All 256 threads together hold the **entire 256×32 tile** (8192 elements) +- Each thread's buffer lives in its **own VGPRs** +- No two threads own the same element + +**Distributed Ownership Analogy:** +Think of a tile as a **jigsaw puzzle**: +- The **tile distribution** is the cutting pattern +- Each **thread** gets one puzzle piece (its slice) +- Each **`static_distributed_tensor`** is a box holding all pieces +- Each thread's **`thread_buf_`** is its individual piece in its own registers + +--- + +### Step 4: The GEMM Loop + +```cpp +// Initialize C accumulator to zero +auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; +tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +index_t iCounter = num_loop; // Number of K iterations + +while(iCounter > 0) +{ + // 1. Load from DRAM to registers + a_block_tile = load_tile(a_copy_dram_window); // DRAM → VGPRs + b_block_tile = load_tile(b_copy_dram_window); // DRAM → VGPRs + + // 2. Move windows for next iteration + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); // Step by (0, 32) + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // Step by (0, 32) + + // 3. Store from registers to LDS + store_tile(a_copy_lds_window, a_block_tile); // VGPRs → LDS + store_tile(b_copy_lds_window, b_block_tile); // VGPRs → LDS + + // 4. Synchronize threads (ensure all data is in LDS) + block_sync_lds(); + + // 5. Compute GEMM using data in LDS + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + // 6. Synchronize threads (before overwriting LDS in next iteration) + block_sync_lds(); + + iCounter--; +} + +return c_block_tile; // Return accumulated result in registers +``` + +--- + +## Detailed Loop Breakdown + +### Phase 1: Load (DRAM → VGPRs) + +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +**What happens:** +1. Each thread reads **its assigned elements** from DRAM (determined by tile distribution) +2. Data is loaded into **per-thread registers** (VGPRs) +3. Uses **vectorized loads** for efficiency (e.g., loading 8 fp16 values at once) + +**Example for Thread 0:** +``` +Thread 0 loads: + A[0,0:7] (8 fp16 values, one vector load) + A[1,0:7] (8 fp16 values, one vector load) + ... +``` + +### Phase 2: Move Windows + +```cpp +constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); +move_tile_window(a_copy_dram_window, a_dram_tile_window_step); +``` + +**What happens:** +- The tile window **slides along the K dimension** by `KPerBlock` (32 in our example) +- This prepares for the next K iteration +- The window origin moves from `(0, 0)` → `(0, 32)` → `(0, 64)` → ... + +**Visualization for Problem Size 512×256×64:** +``` +Matrix A (512×64): +┌─────────────────────────────────────┐ +│ Block 0: rows 0-255 │ +│ ┌──────────┬──────────┐ │ +│ │ K=0:31 │ K=32:63 │ │ ← Window slides right +│ │ Iter 0 │ Iter 1 │ │ +│ └──────────┴──────────┘ │ +└─────────────────────────────────────┘ +``` + +### Phase 3: Store (VGPRs → LDS) + +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +**What happens:** +1. Each thread writes **its elements** from registers to LDS +2. Uses the **same distribution** as the DRAM load +3. Data is now in **shared memory**, accessible to all threads in the block + +**Why this step?** +- GEMM computation needs **all threads** to access **all data** +- Registers are per-thread; LDS is shared across the block +- LDS acts as a "staging area" for collaborative computation + +### Phase 4: Synchronize + +```cpp +block_sync_lds(); +``` + +**What happens:** +- All threads in the block **wait** until everyone has finished storing to LDS +- Ensures no thread starts reading from LDS before all writes are complete +- Critical for correctness! + +### Phase 5: GEMM Computation + +```cpp +block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); +``` + +**What happens:** +1. The warp-level GEMM reads data from LDS +2. Performs matrix multiplication using MFMA instructions +3. Accumulates results into `c_block_tile` (in registers) + +**Note:** `c_block_tile` stays in registers throughout all K iterations, accumulating results. + +### Phase 6: Synchronize Again + +```cpp +block_sync_lds(); +``` + +**What happens:** +- Ensures all threads have finished reading from LDS +- Safe to overwrite LDS in the next iteration + +--- + +## Memory Flow Diagram + +``` +Iteration 0 (K=0:31): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 0:31] │ │ thread) │ │ │ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (VGPRs) │ + └──────────┘ + +Iteration 1 (K=32:63): +┌─────────┐ load_tile ┌──────────┐ store_tile ┌─────────┐ +│ DRAM │ ────────────> │ VGPRs │ ─────────────> │ LDS │ +│ A[0:255,│ │ (per- │ │ A_block │ +│ 32:63] │ │ thread) │ │ (reused)│ +└─────────┘ └──────────┘ └─────────┘ + │ + │ block_gemm + ↓ + ┌──────────┐ + │ c_block_ │ + │ tile │ + │ (accum.) │ + └──────────┘ +``` + +--- + +## Example: Problem Size 512×256×64 + +### Block 0 Computation + +**Input:** +- `a_dram_block_window_tmp`: Covers A[0:255, 0:31] initially +- `b_dram_block_window_tmp`: Covers B[0:127, 0:31] initially (B is transposed) +- `num_loop`: 2 (since K=64, KPerBlock=32) + +**Iteration 0:** +1. Load A[0:255, 0:31] and B[0:127, 0:31] from DRAM to VGPRs +2. Move windows: A → [0:255, 32:63], B → [0:127, 32:63] +3. Store to LDS +4. Compute: `C[0:255, 0:127] += A[0:255, 0:31] × B[0:127, 0:31]^T` + +**Iteration 1:** +1. Load A[0:255, 32:63] and B[0:127, 32:63] from DRAM to VGPRs +2. Move windows: A → [0:255, 64:95], B → [0:127, 64:95] (out of bounds, but loop ends) +3. Store to LDS +4. Compute: `C[0:255, 0:127] += A[0:255, 32:63] × B[0:127, 32:63]^T` + +**Output:** +- `c_block_tile`: Contains C[0:255, 0:127] in distributed registers + +--- + +## Key Concepts Summary + +### 1. Tile Distribution +- **Maps threads to data elements** for load/store operations +- Each thread knows exactly which elements it's responsible for +- Enables **parallel, vectorized** memory access +- **Same distribution** used for DRAM load and LDS store + +### 2. Static Distributed Tensor +- **Per-thread register storage** (VGPRs) +- Each thread owns a **different slice** of the tile +- **Compile-time sized** for zero-overhead abstraction +- Used for: `a_block_tile`, `b_block_tile`, `c_block_tile` + +### 3. Tile Window Movement +- Windows **slide** over larger tensors +- Enables iteration over the K dimension +- `move_tile_window(window, step)` updates the origin + +### 4. LDS as Staging Area +- **Shared memory** accessible to all threads in a block +- Required because GEMM needs all threads to access all data +- **Reused** across K iterations (same LDS buffer) + +### 5. Synchronization +- `block_sync_lds()` ensures memory consistency +- **Before GEMM**: All stores to LDS are complete +- **After GEMM**: All reads from LDS are complete + +--- + +## Deep Dive: `static_distributed_tensor` Mechanics + +### How Tile Distribution Creates Per-Thread Storage + +When you call: +```cpp +using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); +ABlockTile a_block_tile; +``` + +**Step 1: Extract Thread Tensor Descriptor** + +The tile distribution contains a `ys_to_d_descriptor` that maps: +- **Y dimensions** (logical tile coordinates, e.g., M, K) +- **D dimension** (per-thread register index, linearized) + +```cpp +using ThreadTensorDesc = + decltype(StaticTileDistribution{}.get_ys_to_d_descriptor()); +``` + +**Step 2: Calculate Per-Thread Buffer Size** + +```cpp +static constexpr index_t kThreadElementSpaceSize = + ThreadTensorDesc{}.get_element_space_size(); + +static constexpr index_t get_thread_buffer_size() +{ + return kThreadElementSpaceSize / PackedSize; +} +``` + +**Example:** +- 256×32 tile distributed across 256 threads +- Each thread owns 32 elements (one row) +- `thread_buffer_size = 32` (for PackedSize=1) + +**Step 3: Allocate Thread Buffer** + +```cpp +thread_buffer thread_buf_; +``` + +This is essentially: +```cpp +fp16_t data[32]; // Per-thread register array (VGPRs) +``` + +### Usage in Load/Store Operations + +**Load from DRAM:** +```cpp +a_block_tile = load_tile(a_copy_dram_window); +``` + +What happens internally: +1. Each thread queries the tile distribution: "Which elements do I own?" +2. Thread 0 learns it owns A[0,0:31] +3. Thread 0 loads those elements from DRAM into `a_block_tile.thread_buf_[0:31]` +4. All 256 threads do this **in parallel** + +**Store to LDS:** +```cpp +store_tile(a_copy_lds_window, a_block_tile); +``` + +What happens internally: +1. Each thread reads from its `a_block_tile.thread_buf_` +2. Thread 0 writes A[0,0:31] from its registers to LDS +3. All 256 threads do this **in parallel** +4. After `block_sync_lds()`, the entire tile is in shared LDS + +### Distributed Indexing + +The `static_distributed_tensor` supports compile-time indexing: + +```cpp +// Access using distributed indices +auto value = a_block_tile(tile_distributed_index{}); +``` + +Internally: +1. Convert distributed index → Y index (logical tile coordinates) +2. Calculate buffer offset using `ThreadTensorDesc` +3. Access `thread_buf_[offset]` + +All of this happens **at compile time** with zero runtime overhead! + +### Why This Design? + +**Benefits:** +1. **Parallel Memory Access**: All threads load/store simultaneously +2. **Vectorization**: Each thread can use vector loads (e.g., 8×fp16 at once) +3. **Zero Overhead**: All indexing resolved at compile time +4. **Type Safety**: Distribution mismatch caught at compile time +5. **Register Pressure**: Compiler knows exact VGPR usage + +**Trade-offs:** +- Requires compile-time tile sizes +- Distribution must be static +- More complex type system + +### Memory Hierarchy Summary + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) │ +│ Full matrices A, B, C │ +└─────────────────────────────────────────────────────────────┘ + │ + │ load_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Per-Thread Registers) │ +│ Thread 0: a_block_tile.thread_buf_ = [A[0,0:31]] │ +│ Thread 1: a_block_tile.thread_buf_ = [A[1,0:31]] │ +│ ... │ +│ Thread 255: a_block_tile.thread_buf_ = [A[255,0:31]] │ +│ │ +│ ← static_distributed_tensor manages this distribution │ +└─────────────────────────────────────────────────────────────┘ + │ + │ store_tile (parallel, vectorized) + ↓ +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) │ +│ Entire block tile (256×32) │ +│ Accessible to all threads in block │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key Insight:** +`static_distributed_tensor` is the abstraction that enables efficient, parallel data movement between DRAM and LDS through per-thread VGPRs, with all coordination happening at compile time. + + + diff --git a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt new file mode 100644 index 0000000000..e16977921a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp) + +target_compile_options(tile_tutorial_naive_gemm PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) + +add_dependencies(tutorials tile_tutorial_naive_gemm) \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md new file mode 100644 index 0000000000..43cb01fb36 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md @@ -0,0 +1,618 @@ +# Host-Level Pipeline: Orchestrating Block-Level GEMM + +This document explains the **host-level pipeline** (`PracticeGemmHostPipeline`), which orchestrates the distribution of work across thread blocks and manages the high-level flow of the GEMM computation. + +## Overview + +The host-level pipeline is responsible for: +1. **Calculating tile coverage**: How many tiles are needed to cover matrices A, B, and C +2. **Block-to-tile mapping**: Assigning each thread block to a specific tile +3. **Creating tile windows**: Establishing sliding windows over tensor views +4. **Delegating computation**: Calling the block-level pipeline to perform actual GEMM +5. **Storing results**: Writing computed tiles from registers (VGPRs) back to DRAM + +```cpp +template +struct PracticeGemmHostPipeline +{ + template + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram) const + { + // 1. Calculate problem dimensions and tile coverage + // 2. Map thread block to tile coordinates + // 3. Create tile windows over A and B + // 4. Call block-level pipeline to compute + // 5. Store result to C + } +}; +``` + +--- + +## Step 1: Calculate Problem Dimensions and Tile Coverage + +```cpp +// Size of the entire problem +const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K +const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N +const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + +// Size of the block tile +const auto MPerBlock = BlockTile::at(number<0>{}); // 256 +const auto NPerBlock = BlockTile::at(number<1>{}); // 128 +const auto KPerBlock = BlockTile::at(number<2>{}); // 32 + +// Number of block tiles needed to cover C matrix +const auto num_tile_n = integer_divide_ceil(N, NPerBlock); // ceil(256/128) = 2 +const auto num_tile_m = integer_divide_ceil(M, MPerBlock); // ceil(512/256) = 2 +``` + +### What's Happening: + +1. **Extract problem dimensions** from tensor descriptors: + - `M = 512`: Rows in A and C + - `N = 256`: Columns in B and C + - `K = 64`: Inner dimension (columns of A, rows of B) + +2. **Get block tile sizes** from the `BlockTile` configuration: + - `MPerBlock = 256`: Each block processes 256 rows + - `NPerBlock = 128`: Each block processes 128 columns + - `KPerBlock = 32`: Each block processes 32 elements in K dimension per iteration + +3. **Calculate tile coverage**: + - `num_tile_m = ceil(M / MPerBlock) = ceil(512/256) = 2` tiles in M direction + - `num_tile_n = ceil(N / NPerBlock) = ceil(256/128) = 2` tiles in N direction + - **Total tiles = 2 × 2 = 4 tiles** → We need **4 thread blocks**! + +### Visual Representation: + +``` +Matrix C (512 × 256): +┌──────────────────────┬──────────────────────┐ +│ Tile (0,0) │ Tile (0,1) │ ← num_tile_n = 2 +│ 256×128 │ 256×128 │ +│ Block 0 │ Block 1 │ +│ │ │ +├──────────────────────┼──────────────────────┤ +│ Tile (1,0) │ Tile (1,1) │ +│ 256×128 │ 256×128 │ +│ Block 2 │ Block 3 │ +│ │ │ +└──────────────────────┴──────────────────────┘ + ↑ + num_tile_m = 2 + +Total blocks needed = 2 × 2 = 4 blocks + +Each block computes one 256×128 tile of the output matrix C. +``` + +### How Blocks Cover Matrices A and B: + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬──────┐ ┌─────────────┬──────┐ +│ Block 0,2 │ K │ │ Block 0,1 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 0-255 │ │ │ 0-127 │ │ +├─────────────┼──────┤ ├─────────────┼──────┤ +│ Block 1,3 │ K │ │ Block 2,3 │ K │ +│ uses rows │ → │ │ uses rows │ → │ +│ 256-511 │ │ │ 128-255 │ │ +└─────────────┴──────┘ └─────────────┴──────┘ + 256 rows 64 cols 128 rows 64 cols + +Each block needs to iterate over K dimension (64/32 = 2 iterations) +``` + +--- + +## Step 2: Map Thread Block to Tile Coordinates + +```cpp +// Get block id (0 to total_blocks - 1) +const auto id_block = get_block_id(); + +// Map block id to 2D tile coordinates +const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); +const auto tile_id = block2tile(id_block); + +const auto tile_id_m = tile_id.at(number<0>{}); // M coordinate +const auto tile_id_n = tile_id.at(number<1>{}); // N coordinate +``` + +### What's Happening: + +Each thread block needs to know **which tile of the output matrix C it should compute**. The `MakeBlock2TileMap` function creates a mapping from linear block ID to 2D tile coordinates. + +### The `MakeBlock2TileMap` Function: + +```cpp +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + // Create a merge transform: (N0, M0) → linear index + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + // Convert linear block_id back to 2D coordinates + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + // Return (m_idx, n_idx) - note the swap! + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; +} +``` + +### In Our Example (2×2 Grid): + +```cpp +// Block 0: +id_block = 0 +tile_id = block2tile(0) = (0, 0) // Top-left tile +tile_id_m = 0, tile_id_n = 0 + +// Block 1: +id_block = 1 +tile_id = block2tile(1) = (1, 0) // Bottom-left tile +tile_id_m = 1, tile_id_n = 0 + +// Block 2: +id_block = 2 +tile_id = block2tile(2) = (0, 1) // Top-right tile +tile_id_m = 0, tile_id_n = 1 + +// Block 3: +id_block = 3 +tile_id = block2tile(3) = (1, 1) // Bottom-right tile +tile_id_m = 1, tile_id_n = 1 +``` + +**Key Point**: Each of the 4 blocks knows exactly which 256×128 tile of C it's responsible for computing! + +--- + +## Step 3: Calculate Tile Origin and Create Tile Windows + +```cpp +// Calculate the starting position of this tile in the global matrix +const auto tile_origin_m = tile_id_m * MPerBlock; // e.g., Block 1: 1 * 256 = 256 +const auto tile_origin_n = tile_id_n * NPerBlock; // e.g., Block 2: 1 * 128 = 128 + +// Create tile windows over A and B tensor views +const auto a_block_window = make_tile_window( + a_dram, // Tensor view over A + make_tuple(number{}, number{}), // Window size: 256×32 + {tile_origin_m, 0} // Origin: varies by block +); + +const auto b_block_window = make_tile_window( + b_dram, // Tensor view over B + make_tuple(number{}, number{}), // Window size: 128×32 + {tile_origin_n, 0} // Origin: varies by block +); +``` + +### Tile Origins for Each Block: + +```cpp +// Block 0 (Tile 0,0): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 1 (Tile 1,0): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 0 * 128 = 0 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (0, 0) → covers B rows 0-127 + +// Block 2 (Tile 0,1): +tile_origin_m = 0 * 256 = 0 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (0, 0) → covers A rows 0-255 +b_block_window origin: (128, 0) → covers B rows 128-255 + +// Block 3 (Tile 1,1): +tile_origin_m = 1 * 256 = 256 +tile_origin_n = 1 * 128 = 128 +a_block_window origin: (256, 0) → covers A rows 256-511 +b_block_window origin: (128, 0) → covers B rows 128-255 +``` + +### What are Tile Windows? + +A **tile window** is a **sliding window** over a larger tensor view. It: +- Defines a **rectangular region** within the tensor +- Has a **fixed size** (e.g., 256×32 for A) +- Has an **origin** (starting position) +- Can be **moved** to access different regions +### Visual Representation (Block 0 Example): + +``` +Matrix A (512 × 64): Matrix B (256 × 64): +┌─────────────┬─────────────┐ ┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ │ │ ┏━━━━━━━━━┓ │ │ +│ ┃ Window ┃ │ │ │ ┃ Window ┃ │ │ +│ ┃ 256×32 ┃ │ │ │ ┃ 128×32 ┃ │ │ +│ ┃ K=0-31 ┃ │ │ │ ┃ K=0-31 ┃ │ │ +│ ┗━━━━━━━━━┛ │ │ │ ┗━━━━━━━━━┛ │ │ +│ │ │ ├─────────────┼─────────────┤ +├─────────────┼─────────────┤ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +│ │ │ │ │ │ +└─────────────┴─────────────┘ └─────────────┴─────────────┘ + Origin: (0, 0) Origin: (0, 0) + Covers rows 0-255 Covers rows 0-127 + Covers cols 0-31 (first K iteration) Covers cols 0-31 (first K iteration) +``` + +**Note**: The window initially covers K columns 0-31. It will move to cover K columns 32-63 in the next iteration. + +### Tile Window Properties: + +```cpp +// Tile window structure (conceptual): +struct tile_window { + TensorView& tensor_view; // Reference to underlying tensor + Tuple window_lengths; // Size of the window (256, 32) + MultiIndex window_origin; // Starting position (0, 0) + + // Can move the window: + void move(MultiIndex step); // Shift window by step + + // Access data through the window: + auto load(); // Load data from windowed region +}; +``` + + +### Tile Window Movement: Iterating Over K Dimension + +In our example, **K=64** but **KPerBlock=32**, so we need **2 iterations** over the K dimension: + +``` +Matrix A (512 × 64) - Block 0's view: +┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ ← Window slides along K +│ ┃ 256×32 ┃ │ ║ 256×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ │ │ +│ Block 1's │ │ +│ region │ │ +└─────────────┴─────────────┘ + +Matrix B (256 × 64) - Block 0's view: +┌─────────────┬─────────────┐ +│ ┏━━━━━━━━━┓ │ ╔═══════════╗ │ +│ ┃ Iter 0 ┃ │ ║ Iter 1 ║ │ +│ ┃ 128×32 ┃ │ ║ 128×32 ║ │ +│ ┃ K=0-31 ┃ │ ║ K=32-63 ║ │ +│ ┗━━━━━━━━━┛ │ ╚═══════════╝ │ +├─────────────┼─────────────┤ +│ Block 2's │ │ +│ region │ │ +└─────────────┴─────────────┘ +``` + +### How Windows Move (Conceptual - handled by block pipeline): + +```cpp +// Iteration 0: +a_block_window origin: (tile_origin_m, 0) // K columns 0-31 +b_block_window origin: (tile_origin_n, 0) // K columns 0-31 +// Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] + +// Move windows to next K position: +move_tile_window(a_block_window, {0, 32}); +move_tile_window(b_block_window, {0, 32}); + +// Iteration 1: +a_block_window origin: (tile_origin_m, 32) // K columns 32-63 +b_block_window origin: (tile_origin_n, 32) // K columns 32-63 +// Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] + +// Final result: +// C_tile = C_partial_0 + C_partial_1 +``` + +**Key Insight**: The tile windows **slide along the K dimension** to cover the full inner product. Each block accumulates partial results across K iterations to compute its final tile of C. + +--- + +## Step 4: Delegate to Block-Level Pipeline + +```cpp +// Get the block-level pipeline from policy +constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + +// Calculate number of K iterations needed +int num_loops_k = integer_divide_ceil(K, KPerBlock); // ceil(64/32) = 2 + +// Allocate shared memory (LDS) for block-level computation +__shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + +// Call block-level pipeline to compute C tile +const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); +``` + +### What's Happening: + +1. **Retrieve block pipeline**: The policy provides the block-level GEMM implementation +2. **Calculate K iterations**: How many times to iterate over the K dimension + - In our example: `K=64, KPerBlock=32` → **2 iterations** + - Each iteration processes 32 elements of the K dimension + - Results are accumulated across iterations + +3. **Allocate shared memory**: + - `__shared__` declares memory shared by all threads in the block + - `GetStaticLDSSize()` returns the required size in bytes + - This memory is used for: + - Staging data from DRAM → LDS + - Cooperative loading by threads + - Fast access during computation + +4. **Execute block pipeline**: + - Takes A and B tile windows as input + - Performs the GEMM computation: `C_tile = A_tile × B_tile` + - Returns result in `c_block_tile` (stored in VGPRs - registers) + +### Memory Hierarchy During Computation: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ DRAM (Global Memory) - Slowest, Largest │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ A matrix │ │ B matrix │ │ C matrix │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load ↑ store +┌─────────────────────────────────────────────────────────────┐ +│ LDS (Shared Memory) - Fast, Limited Size (~64KB) │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ A_tile │ │ B_tile │ ← Staged here │ +│ │ (p_smem) │ │ (p_smem) │ │ +│ └─────────────┘ └─────────────┘ │ +└─────────────────────────────────────────────────────────────┘ + ↓ load ↓ load +┌─────────────────────────────────────────────────────────────┐ +│ VGPRs (Registers) - Fastest, Smallest (~256 regs/thread) │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ c_block_tile (accumulated result) │ │ +│ │ Computation happens here using MFMA instructions │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### Block Pipeline Responsibilities: + +The block pipeline (called here) will: +1. Load A and B tiles from DRAM → LDS (cooperative loading) +2. Distribute work among warps +3. Each warp loads its portion from LDS → VGPRs +4. Perform MFMA operations: `C += A × B` +5. Accumulate results in VGPRs +6. Return final `c_block_tile` in registers + +--- + +## Step 5: Store Results to DRAM + +```cpp +// Create a tile window over C for writing results +auto c_window = make_tile_window( + c_dram, // Tensor view over C + make_tuple(number{}, number{}), // Window size: 256×128 + {tile_origin_m, tile_origin_n} // Origin: varies by block +); + +// Store computed tile from VGPRs to DRAM +store_tile(c_window, c_block_tile); +``` + +### C Window Origins for Each Block: + +```cpp +// Block 0: Writes to top-left tile +c_window origin: (0, 0) → writes to C[0:255, 0:127] + +// Block 1: Writes to bottom-left tile +c_window origin: (256, 0) → writes to C[256:511, 0:127] + +// Block 2: Writes to top-right tile +c_window origin: (0, 128) → writes to C[0:255, 128:255] + +// Block 3: Writes to bottom-right tile +c_window origin: (256, 128) → writes to C[256:511, 128:255] +``` + +### What's Happening: + +1. **Create C tile window**: + - Size: 256×128 (matches our block tile size) + - Origin: Varies by block - each block writes to its assigned region + - This window defines **where** to write the results + +2. **Store tile to DRAM**: + - `c_block_tile`: Computed results in VGPRs (registers) + - `c_window`: Destination window in DRAM + - `store_tile()`: Efficiently writes data from registers → DRAM + +### The `store_tile` Function: + +Recall from our earlier discussion, `store_tile` does: + +```cpp +template +void store_tile(TileWindow& tile_window_tmp, + const DistributedTensor& dstr_tensor) +{ + // 1. Extract tile distribution from distributed tensor + using TileDstr = typename DistributedTensor::TileDistribution; + + // 2. Upgrade simple tile window to one with distribution + auto tile_window = make_tile_window( + tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + TileDstr{} // Add distribution info + ); + + // 3. Store using vectorized writes + tile_window.store(dstr_tensor); +} +``` + +### Memory Flow: + +``` +VGPRs (Registers) DRAM (Global Memory) +┌─────────────────────┐ ┌─────────────────────┐ +│ c_block_tile │ │ C matrix │ +│ ┌───┬───┬───┬───┐ │ │ ┌───────────────┐ │ +│ │W0 │W1 │W2 │W3 │ │ store_tile │ │ │ │ +│ ├───┼───┼───┼───┤ │ ==========> │ │ c_window │ │ +│ │...│...│...│...│ │ vectorized │ │ (256×128) │ │ +│ └───┴───┴───┴───┘ │ │ │ │ │ +│ Distributed across │ │ └───────────────┘ │ +│ threads/warps │ │ Origin: (0, 0) │ +└─────────────────────┘ └─────────────────────┘ + +Each thread writes its portion using vector stores (e.g., float4) +``` + +### Store Optimization: + +The `store_tile` function: +- Uses **vectorized stores** (write multiple elements at once) +- Ensures **coalesced memory access** (adjacent threads write adjacent memory) +- Respects **tile distribution** (each thread knows what data it owns) +- Handles **out-of-bounds** checking (for partial tiles at boundaries) + +--- + +## Complete Flow Visualization + +Let's trace the complete flow for **Block 0** (other blocks follow the same pattern): + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Step 1: Calculate Tile Coverage │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ M=512, N=256, K=64 │ │ +│ │ MPerBlock=256, NPerBlock=128, KPerBlock=32 │ │ +│ │ num_tile_m = ceil(512/256) = 2 │ │ +│ │ num_tile_n = ceil(256/128) = 2 │ │ +│ │ Total blocks needed = 2 × 2 = 4 blocks │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 2: Map Block to Tile (Block 0 example) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Block ID: 0 │ │ +│ │ Tile coordinates: (0, 0) - top-left tile │ │ +│ │ Tile origin: (0, 0) │ │ +│ │ │ │ +│ │ (Blocks 1,2,3 get different tile coordinates) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 3: Create Tile Windows │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ a_block_window: 256×32 starting at (0,0) over A │ │ +│ │ b_block_window: 128×32 starting at (0,0) over B │ │ +│ │ Windows initially cover K columns 0-31 │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 4: Execute Block Pipeline (2 K iterations) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Allocate shared memory (LDS) │ │ +│ │ Call block_gemm_pipeline(a_window, b_window, 2, p_smem) │ │ +│ │ │ │ +│ │ K Iteration 0 (K=0-31): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: C_partial_0 = A[:, 0:31] × B[:, 0:31] │ │ +│ │ └─ Move windows: {0, 32} │ │ +│ │ │ │ +│ │ K Iteration 1 (K=32-63): │ │ +│ │ ├─ Load A tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Load B tile: DRAM → LDS → VGPRs │ │ +│ │ ├─ Compute: C_partial_1 = A[:, 32:63] × B[:, 32:63] │ │ +│ │ └─ Accumulate: C_tile = C_partial_0 + C_partial_1 │ │ +│ │ │ │ +│ │ Return c_block_tile in VGPRs (256×128 accumulated result) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Step 5: Store Results │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Create c_window: 256×128 starting at (0,0) over C │ │ +│ │ store_tile(c_window, c_block_tile) │ │ +│ │ └─ Write from VGPRs → DRAM (vectorized stores) │ │ +│ │ │ │ +│ │ Block 0 writes to C[0:255, 0:127] │ │ +│ │ (Other blocks write to their respective regions) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + +All 4 blocks execute in parallel, each computing its assigned 256×128 tile! +``` + +--- + +## Key Concepts Summary + +### 1. **Tile Coverage** +- Determines how many thread blocks are needed +- Each block processes one tile of the output matrix C +- Calculated as `ceil(dimension / tile_size)` + +### 2. **Block-to-Tile Mapping** +- Maps linear block ID to 2D tile coordinates +- Uses column-major ordering for better memory coalescing +- Each block knows which tile it's responsible for + +### 3. **Tile Windows** +- **Sliding windows** over larger tensor views +- Define a rectangular region with fixed size and movable origin +- Provide efficient, structured access to tensor data +- Can be moved to access different regions (e.g., for K iterations) + +### 4. **Memory Hierarchy** +- **DRAM (Global)**: Largest, slowest - stores full matrices +- **LDS (Shared)**: Medium, fast - stages tiles for cooperative access +- **VGPRs (Registers)**: Smallest, fastest - performs computation + +### 5. **Data Flow** +``` +DRAM → Tile Windows → LDS → VGPRs → Computation → VGPRs → DRAM + ↑ ↓ + A, B matrices C matrix +``` + +--- + +## Next Steps + +The host-level pipeline has set up the work and delegated to the block-level pipeline. Next, we'll explore: +- **Block-level pipeline**: How tiles are loaded, distributed to warps, and computed +- **Warp-level pipeline**: How warps perform MFMA operations +- **Memory optimization**: LDS usage, bank conflicts, coalescing + +The host level provides the **orchestration**, while the block and warp levels provide the **execution**! + diff --git a/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md new file mode 100644 index 0000000000..7cd0d06fc5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md @@ -0,0 +1,464 @@ +# PracticeGemmKernel: Understanding the Kernel Entry Point + +This document explains the `PracticeGemmKernel` structure, which serves as the **entry point** for our GEMM GPU kernel. We'll dive deep into how raw memory is transformed into structured tensor views. + +## Overview + +The `PracticeGemmKernel` is a templated struct that: +1. Takes raw device memory pointers for matrices A, B, and C +2. Wraps them into **tensor views** - logical, structured views over physical memory +3. Dispatches to the host-level pipeline for computation + +```cpp +template +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + // Step 1: Create tensor views over raw memory + auto a_dram = make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + // Step 2: Dispatch to host-level pipeline + PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); + } +}; +``` + +--- + +## What are Tensor Views? + +A **tensor view** is a **logical, structured view over raw physical memory**. It doesn't own or allocate memory—it simply provides a way to interpret and access existing memory as a multi-dimensional tensor. + +### Key Components of a Tensor View: + +1. **Memory Type**: Where the data lives (global/DRAM, LDS/shared, registers) +2. **Raw Pointer**: Points to the actual data in memory +3. **Shape**: Dimensions of the tensor (e.g., M×K for matrix A) +4. **Strides**: How to navigate through memory to access elements +5. **Guaranteed Vector Length**: How many consecutive elements can be loaded in one vector instruction +6. **Guaranteed Vector Stride**: The stride of those vectorizable elements + +--- + +## The Memory Abstraction Hierarchy + +CK Tile uses a three-layer abstraction to go from raw memory to structured tensors: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Layer 3: TENSOR VIEW │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ • Logical multi-dimensional structure │ │ +│ │ • Shape: (M, K) = (256, 32) │ │ +│ │ • Strides: (32, 1) for row-major layout │ │ +│ │ • Provides: operator[], coordinate-based access │ │ +│ │ • Knows: How to map (i,j) → linear offset │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 2: BUFFER VIEW │ │ +│ │ ┌─────────────────────────────────────────────────────┐ │ │ +│ │ │ • Linear view of memory │ │ │ +│ │ │ • Pointer: p_data_ → device memory │ │ │ +│ │ │ • Size: Total number of elements │ │ │ +│ │ │ • Address space: global/LDS/generic │ │ │ +│ │ │ • Provides: Vectorized loads/stores, bounds checking│ │ │ +│ │ └─────────────────────────────────────────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────┘ │ +│ ↓ wraps │ +│ ┌─────────────────────────────────────────────────────────┐ │ +│ │ Layer 1: RAW PHYSICAL MEMORY │ │ +│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ │ │ +│ │ │ 0.0 │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ 5.0 │ 6.0 │ 7.0 │ ... │ │ │ +│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘ │ │ +│ │ ↑ │ │ +│ │ p_a (raw pointer from hipMalloc) │ │ +│ └─────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Deep Dive: `make_naive_tensor_view` + +Let's break down the function call for matrix A: + +```cpp +auto a_dram = make_naive_tensor_view( + p_a, // Raw pointer to device memory + make_tuple(M, K), // Shape: (256, 32) + make_tuple(stride_a, 1), // Strides: (32, 1) - row-major + number<8>{}, // Guaranteed vector length + number<1>{} // Guaranteed vector stride +); +``` + +### Function Signature: + +```cpp +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view(DataType* __restrict__ p, + const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + // Step 1: Create tensor descriptor (shape + stride information) + auto desc = make_naive_tensor_descriptor(lengths, + strides, + number{}, + number{}); + + // Step 2: Create buffer view (pointer + size + address space) + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); + + // Step 3: Combine into tensor view + return tensor_view{buffer_view, desc}; +} +``` + +--- + +## Parameter Breakdown + +### 1. **Template Parameter: `address_space_enum::global`** + +Specifies where the memory lives: +- `global`: GPU global memory (DRAM) - slowest but largest +- `lds`: Local Data Share (shared memory) - fast, limited size +- `generic`: Generic address space +- `vgpr`: Vector General Purpose Registers - fastest, smallest + +In our case, `global` means the data is in GPU DRAM. + +### 2. **`p_a` - Raw Pointer** + +The raw device memory pointer returned by `hipMalloc`. Points to the start of the matrix data. + +### 3. **`make_tuple(M, K)` - Shape/Lengths** + +Defines the logical dimensions of the tensor: +- For matrix A: `(256, 32)` means 256 rows, 32 columns +- This is the **logical view**, independent of how data is physically laid out + +### 4. **`make_tuple(stride_a, 1)` - Strides** + +Defines how to navigate through memory: +- **Stride for dimension 0 (rows)**: `stride_a = K = 32` + - To move to the next row, skip 32 elements +- **Stride for dimension 1 (columns)**: `1` + - To move to the next column, skip 1 element + +**Row-major layout example:** +``` +Memory: [a₀₀, a₀₁, a₀₂, ..., a₀₃₁, a₁₀, a₁₁, a₁₂, ..., a₁₃₁, ...] + ↑ ↑ + Row 0 starts here Row 1 starts here (offset = 32) + +To access element A[i][j]: + offset = i * stride_a + j * 1 + = i * 32 + j +``` + +### 5. **`number<8>{}` - Guaranteed Last Dimension Vector Length** + +This tells the tensor view: **"The last dimension (K) is guaranteed to have at least 8 consecutive elements that can be loaded together in a single vector instruction."** + +#### Why is this important? + +Modern GPUs can load multiple elements in one instruction (vectorized loads): +- `float4`: Load 4 floats at once +- `float8`: Load 8 floats at once (if supported) + +By specifying `number<8>{}`, we're telling the system: +- "You can safely use vector loads of up to 8 elements" +- "The memory alignment and layout support this" + +**Example:** +```cpp +// Without vectorization (slow): +for (int j = 0; j < 8; j++) { + data[j] = memory[offset + j]; // 8 separate loads +} + +// With vectorization (fast): +float8 vec = *reinterpret_cast(&memory[offset]); // 1 load! +``` + +### 6. **`number<1>{}` - Guaranteed Last Dimension Vector Stride** + +This specifies the **stride between consecutive vectorizable elements** in the last dimension. + +- `number<1>{}` means: "Consecutive elements in the last dimension are contiguous in memory (stride = 1)" +- This confirms that elements `A[i][0], A[i][1], A[i][2], ..., A[i][7]` are stored consecutively + +**Why does this matter?** + +For efficient vectorized loads, elements must be: +1. **Contiguous** (stride = 1) ✓ +2. **Aligned** properly in memory +3. **Within the same cache line** (ideally) + +If the stride were `2`, it would mean: +``` +A[i][0] is at offset 0 +A[i][1] is at offset 2 (not 1!) +A[i][2] is at offset 4 +``` +This would prevent efficient vectorization. + +--- + +## What is a Buffer View? + +A **buffer view** is the middle layer between raw memory and tensor view. It provides: + +### Core Responsibilities: + +1. **Memory Management** + - Holds the raw pointer: `T* p_data_` + - Tracks buffer size: `BufferSizeType buffer_size_` + - Knows the address space: `global`, `lds`, etc. + +2. **Vectorized Access** + ```cpp + template + CK_TILE_DEVICE VectorType get(index_t offset); + ``` + - Provides efficient vector loads/stores + - Handles alignment requirements + +3. **Bounds Checking** (optional) + ```cpp + template + CK_TILE_DEVICE auto get(index_t i, index_t linear_offset); + ``` + - Can optionally check if access is within bounds + - Returns invalid value (default 0) for out-of-bounds access + +4. **Address Space Awareness** + - Uses different load/store instructions based on address space + - Global memory: `global_load`, `global_store` + - LDS: `ds_read`, `ds_write` + +### Buffer View Structure: + +```cpp +template +struct buffer_view +{ + T* p_data_; // Raw pointer + BufferSizeType buffer_size_; // Total elements + remove_cvref_t invalid_element_value_; // Value for OOB access + + // Access operators + const T& operator[](index_t i) const; // Read + T& operator()(index_t i); // Write + + // Vectorized access + template + VectorType get(index_t offset); +}; +``` + +--- + +## Visual Example: Matrix A Memory Layout + +Let's visualize how matrix A (256×32, fp16) is organized: + +### Raw Physical Memory (Linear): +``` +GPU DRAM Address Space: +┌─────────────────────────────────────────────────────────────────┐ +│ Byte 0 │ +│ ↓ │ +│ [a₀₀][a₀₁][a₀₂]...[a₀₃₁][a₁₀][a₁₁][a₁₂]...[a₁₃₁][a₂₀]... │ +│ ↑ ↑ │ +│ Row 0 (32 elements) Row 1 (32 elements) │ +│ │ +│ Total: 256 rows × 32 cols × 2 bytes/element = 16,384 bytes │ +└─────────────────────────────────────────────────────────────────┘ + ↑ + p_a (raw pointer) +``` + +### Buffer View Layer: +``` +buffer_view +┌─────────────────────────────────────────────────────────────────┐ +│ p_data_ = p_a │ +│ buffer_size_ = 256 × 32 = 8,192 elements │ +│ address_space = global (DRAM) │ +│ │ +│ Provides: │ +│ • Linear indexing: buffer_view[i] → element at offset i │ +│ • Vectorized loads: get(offset) → load 4 fp16s at once│ +│ • Bounds checking: is offset < buffer_size_? │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Tensor View Layer: +``` +tensor_view +┌─────────────────────────────────────────────────────────────────┐ +│ Shape: (256, 32) │ +│ Strides: (32, 1) │ +│ Guaranteed vector length: 8 │ +│ Guaranteed vector stride: 1 │ +│ │ +│ Logical 2D View: │ +│ Col: 0 1 2 ... 31 │ +│ Row 0: [a₀₀][a₀₁][a₀₂] ... [a₀₃₁] ← Can vector load 8 at once│ +│ Row 1: [a₁₀][a₁₁][a₁₂] ... [a₁₃₁] │ +│ Row 2: [a₂₀][a₂₁][a₂₂] ... [a₂₃₁] │ +│ ... │ +│ Row 255: [a₂₅₅,₀] ... [a₂₅₅,₃₁] │ +│ │ +│ Provides: │ +│ • Multi-dimensional indexing: A[i][j] │ +│ • Coordinate transformation: (i,j) → linear offset = i*32 + j │ +│ • Tile window creation: Extract sub-tensors │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Complete Flow: Raw Memory → Tensor View + +Let's trace the complete transformation for matrix A: + +### Step 1: Kernel Launch (Host Side) +```cpp +// On host: Allocate device memory +hipMalloc(&p_a, M * K * sizeof(fp16_t)); // Returns raw pointer + +// Launch kernel +kernel<<>>(p_a, p_b, p_c, M, N, K, ...); +``` + +### Step 2: Inside Kernel (Device Side) +```cpp +// Receive raw pointer +const fp16_t* p_a; // Points to GPU DRAM + +// Step 2a: Create tensor descriptor +auto desc = make_naive_tensor_descriptor( + make_tuple(256, 32), // Shape + make_tuple(32, 1), // Strides + number<8>{}, // Vector length + number<1>{} // Vector stride +); +// desc now knows: "This is a 256×32 tensor, row-major, vectorizable by 8" + +// Step 2b: Create buffer view +auto buffer_view = make_buffer_view( + p_a, // Raw pointer + 256 * 32 // Total elements +); +// buffer_view now wraps p_a with size and address space info + +// Step 2c: Create tensor view +auto a_dram = tensor_view{buffer_view, desc}; +// a_dram now provides structured, multi-dimensional access to p_a +``` + +### Step 3: Using the Tensor View +```cpp +// Access element A[i][j] +auto value = a_dram[make_tuple(i, j)]; + +// Create a tile window (sub-tensor) +auto tile = make_tile_window( + a_dram, + make_tuple(16, 16), // 16×16 tile + make_tuple(0, 0) // Starting at origin +); + +// Load tile into registers with vectorization +auto tile_data = load_tile(tile); // Uses vector loads internally! +``` + +--- + +## Why This Abstraction? + +### Benefits: + +1. **Type Safety**: Can't accidentally access wrong dimensions +2. **Performance**: Compiler knows about vectorization opportunities +3. **Flexibility**: Same code works for different memory spaces (DRAM, LDS, registers) +4. **Maintainability**: Logical structure separate from physical layout +5. **Optimization**: Guaranteed vector properties enable aggressive optimizations + +### Example: Without Tensor Views (Manual Indexing) +```cpp +// Ugly, error-prone, hard to optimize: +for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + float val = p_a[tile_offset_i * stride_a + tile_offset_j + i * stride_a + j]; + // Hope the compiler vectorizes this? 🤞 + } +} +``` + +### Example: With Tensor Views (Clean, Optimized) +```cpp +// Clean, safe, automatically vectorized: +auto tile = make_tile_window(a_dram, make_tuple(16, 16), origin); +auto tile_data = load_tile(tile); // Vectorized loads guaranteed! +``` + +--- + +## Summary + +The `PracticeGemmKernel` entry point transforms raw GPU memory into structured, multi-dimensional tensors through a three-layer abstraction: + +1. **Raw Memory**: Linear array of bytes in GPU DRAM +2. **Buffer View**: Adds size, address space, and vectorized access +3. **Tensor View**: Adds shape, strides, and multi-dimensional indexing + +This abstraction enables: +- ✅ Clean, readable code +- ✅ Type-safe multi-dimensional access +- ✅ Automatic vectorization +- ✅ Flexible memory space handling +- ✅ Efficient tile-based computation + +The tensor views created here are then passed to the host-level pipeline, which orchestrates the block-level GEMM computation! + diff --git a/tutorial/ck_tile/01_naive_gemm/README.md b/tutorial/ck_tile/01_naive_gemm/README.md new file mode 100644 index 0000000000..f2caf7d993 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/README.md @@ -0,0 +1,150 @@ +# CK Tile Practice GEMM Example + +This is a practice implementation of a GEMM (General Matrix Multiplication) kernel using the CK Tile API. It demonstrates the fundamental concepts of GPU kernel development using CK Tile's hierarchical tile system. + +## CK Tile API Structure + +In the composable_kernel library's ck_tile API, **A Kernel is composed of a Problem, a Policy and an Epilogue**: + +1. **Problem** describes the shape, data type, data layout, precision of our GEMM matrices +2. **Policy** describes how the data in the matrix (or tile) is mapped to the threads +3. **Epilogue** describes additional computation work performed after the gemm computations (this example does not have an epilogue) + +## Overview + +This example implements a complete GEMM kernel `C = A × B` using the CK Tile framework, showcasing: + +- **Problem Setup** - Setting up the problem (input/output shapes, data types, mathematical operations), composing a kernel (pipeline, policy, epilogue), kernel launch +- **Block-level Pipelining** - creating tensor views, dispatching to block-level GEMM +- **Block-level GEMM Computation** - Block tiles, tile window creation, loading/storing to DRAM and Register memory +- **Warp-level GEMM Computation** - Warp tiles, MFMA level computation + +## Problem Setup and Data Flow + +### Problem Size Configuration +We set the problem size using the M, N and K variables: +```cpp +ck_tile::index_t M = 1024; // Number of rows in A and C +ck_tile::index_t N = 512; // Number of columns in B and C +ck_tile::index_t K = 256; // Number of columns in A, rows in B +``` + +### Host Matrix Creation +Three host matrices A (M×K), B (N×K) and C (M×N) are created, initialized on the CPU and copied over to the GPU global/DRAM memory: +```cpp +// Host tensors with proper strides +ck_tile::HostTensor a_host(a_lengths, a_strides); // M × K +ck_tile::HostTensor b_host(b_lengths, b_strides); // N × K +ck_tile::HostTensor c_host(c_lengths, c_strides); // M × N + +// Initialize with random data +ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); +ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + +// Allocate device memory and transfer data +ck_tile::DeviceMem a_device(a_host); +a_device.ToDevice(a_host.data()); +``` + +### PracticeGemmShape Configuration +A PracticeGemmShape struct holds the dimension of each BlockTile and WaveTile: + +```cpp +using BlockTile = ck_tile::sequence<256, 128, 32>; // M, N, K per block +using WaveTile = ck_tile::sequence<16, 16, 16>; // M, N, K per wave +``` +- A BlockTile of size MxK (256x32) on A matrix and NxK (128x32) on B matrix. A WaveTile of size MxN (16x16) on C matrix. + + +- BlockTiles iterate in K dimension to fetch data required for computing region of C covered by C's block tile. +- BlockTiles are further subdivided into WarpTiles. +- WarpTiles over A and B similarly work together to calculate the WarpTile of C. + +### Problem and Policy Composition +```cpp +// A Problem is composed from Shape and info about the data +using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + +// A Policy is created describing data-to-thread mapping +using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; + +// A Kernel is then composed of Problem and Policy +using gemm_kernel = ck_tile::PracticeGemmKernel; +``` + +### Kernel Launch +`ck_tile::launch_kernel()` is used to launch the kernel on device. It calls the `operator()` function of `PracticeGemmKernel{}`: +```cpp +float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel( + gemm_kernel{}, // Kernel composed of Problem + Policy + kGridSize, // Grid dimensions + kBlockSize, // Block dimensions + 0, // Dynamic shared memory + // Kernel arguments: device buffers and problem dimensions + a_device.GetDeviceBuffer(), b_device.GetDeviceBuffer(), c_device.GetDeviceBuffer(), + M, N, K, stride_a, stride_b, stride_c)); +``` + +### Result Verification +The results from the kernel are compared with results from CPU based computation function: +```cpp +// CPU reference implementation +ck_tile::HostTensor c_host_ref(c_lengths, c_strides); +reference_basic_gemm(a_host, b_host, c_host_ref); + +// Device results +ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + +// Verify correctness +bool pass = ck_tile::check_err(c_host_dev, c_host_ref); +``` + +### Runtime Flow + +The main program (`practice_gemm.cpp`) is the entry point for the runtime flow: + +```cpp +int main() +{ + // 1. Define data types and problem sizes + using ADataType = ck_tile::half_t; + ck_tile::index_t M = 2048, N = 1024, K = 512; + + // 2. Create host tensors and initialize + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + + // 3. Allocate device memory and transfer data + ck_tile::DeviceMem a_device(a_host); + + // 4. Configure tile shapes + using BlockTile = ck_tile::sequence<256, 128, 32>; + using WaveTile = ck_tile::sequence<16, 16, 16>; + + // 5. Launch kernel + using gemm_kernel = ck_tile::PracticeGemmKernel; + float ave_time = ck_tile::launch_kernel(/*...*/); + + // 6. Verify results + bool pass = verify_results(a_host, b_host, c_host); + + // 7. Print performance metrics + print_performance_metrics(ave_time, M, N, K); +} +``` + +## Building and Running + +```bash +# From composable_kernel root directory +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ +make tile_example_practice_gemm -j + +# Run with sample sizes +./bin/tile_example_practice_gemm +``` +This example serves as a foundation for understanding more complex GEMM implementations and optimization strategies in the CK Tile framework. diff --git a/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md new file mode 100644 index 0000000000..d0b8400b9c --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md @@ -0,0 +1,506 @@ +# Practice GEMM: Step-by-Step Code Walkthrough + +This document provides a detailed walkthrough of `practice_gemm.cpp`, explaining each step of implementing a GEMM (General Matrix Multiplication) kernel using the CK Tile API. + +## Overview + +We'll implement `C = A × B` where: +- `A` is an `M × K` matrix +- `B` is an `N × K` matrix (note: transposed layout) +- `C` is an `M × N` matrix + +The implementation uses a hierarchical tiling strategy with two levels: +1. **Block Tiles**: Processed by thread blocks +2. **Wave Tiles**: Processed by warps (wavefronts) within blocks + +--- + +## Step 1: Define Data Types + +```cpp +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using CDataType = float; +using AccDataType = float; +``` + +**What's happening:** +- We use `half_t` (FP16) for input matrices A and B. +- We use `float` (FP32) for output matrix C and accumulation for numerical accuracy +- In typical CK examples, this information is part of a `GemmConfig` struct, but here we define it directly for simplicity +--- + +## Step 2: Define Problem Size + +```cpp +ck_tile::index_t M = 512; +ck_tile::index_t N = 256; +ck_tile::index_t K = 64; +ck_tile::index_t verification = 1; + +ck_tile::index_t stride_a = K; +ck_tile::index_t stride_b = K; +ck_tile::index_t stride_c = N; +``` + +**What's happening:** +- `M = 512`: Number of rows in A and C +- `N = 256`: Number of columns in B and C +- `K = 64`: Inner dimension (columns of A, rows of B) +- Strides define memory layout (row-major for A and C, transposed for B) + +**Memory Layout:** +``` +Matrix A (M×K): Matrix B (N×K): Matrix C (M×N): +[512 rows] [256 rows] [512 rows] +[64 cols] [64 cols] [256 cols] +stride = K stride = K stride = N +``` + +--- + +## Step 3: Create Host Tensors + +```cpp +auto a_lengths = std::array{M, K}; +auto b_lengths = std::array{N, K}; +auto c_lengths = std::array{M, N}; + +auto a_strides = std::array{stride_a, 1}; +auto b_strides = std::array{stride_b, 1}; +auto c_strides = std::array{stride_c, 1}; + +ck_tile::HostTensor a_host(a_lengths, a_strides); +ck_tile::HostTensor b_host(b_lengths, b_strides); +ck_tile::HostTensor c_host(c_lengths, c_strides); +``` + +**What's happening:** +- We create three tensors on the host (CPU) memory +- Each tensor is defined by its shape (`lengths`) and memory layout (`strides`) +- `HostTensor` is a CK Tile utility class that manages CPU memory + +**Stride explanation:** +- For A: `stride_a = K` means moving to the next row requires skipping K elements +- For B: `stride_b = K` means B is stored in transposed format +- For C: `stride_c = N` means row-major layout + +--- + +## Step 4: Initialize Tensors with Random Data + +```cpp +ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); +ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); +c_host.SetZero(); +``` + +**What's happening:** +- A and B are filled with random values in the range [-5.0, 5.0] +- C is initialized to zero (will store the output) + +**Optional: Print Tensor Contents** +```cpp +// Commented out in the code, but available for debugging: +// a_host.print_first_n(10); // Print first 10 elements of A +``` + +The `print_first_n()` helper function can display tensor contents for debugging purposes. + +--- + +## Step 5: Allocate Device Memory and Transfer Data + +```cpp +ck_tile::DeviceMem a_device(a_host); +ck_tile::DeviceMem b_device(b_host); +ck_tile::DeviceMem c_device(c_host); +``` + +**What's happening:** +- `DeviceMem` allocates GPU memory matching the size of host tensors +- The constructor **automatically transfers data from host to device** +- This is a convenience wrapper around `hipMalloc` and `hipMemcpy` + +**Memory Flow:** +``` +CPU (Host) GPU (Device) +┌─────────┐ ┌─────────┐ +│ a_host │ ────────> │a_device │ +│ b_host │ ────────> │b_device │ +│ c_host │ ────────> │c_device │ +└─────────┘ └─────────┘ +``` + +--- + +## Step 6: Configure Hierarchical Tiling + +```cpp +using BlockTile = ck_tile::sequence<256, 128, 32>; +using WaveTile = ck_tile::sequence<16, 16, 16>; +``` + +**What's happening:** +- We define a two-level tiling hierarchy for the GEMM computation + +### Block Tile (256 × 128 × 32) +- **256**: M dimension per block (rows of A and C) +- **128**: N dimension per block (columns of B and C) +- **32**: K dimension per block (inner dimension) +- Each block tile is processed by one **thread block** (256 threads) + +### Wave Tile (16 × 16 × 16) +- **16 × 16**: Output tile dimensions (M × N) per warp iteration +- **16**: K dimension per warp iteration +- Each wave tile is processed by one **warp** (64 threads on AMD GPUs) + +**Important:** The WaveTile (16×16×16) is NOT the same as the MFMA instruction size (32×32×8). The WaveTile represents the work done per warp per iteration, while MFMA is the underlying hardware instruction. Multiple MFMA operations may be needed to compute one wave tile + +**Important Note:** +In this example, the problem size (256 × 128 × 32) is **identical** to the block tile size, so only **one thread block** is needed to compute the entire problem. + +### Tiling Visualization: + +#### Matrix A (M × K = 256 × 32): +``` +┌─────────────────────────────────────┐ +│ One Block Tile (256 × 32) │ +│ ┌────┬────┐ │ +│ │16×│16× │ ← Wave tiles (16×16) │ +│ │ 16│ 16 │ in M×K space │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ ├────┼────┤ │ +│ │ .. │ .. │ 16 tiles in M │ +│ ├────┼────┤ 2 tiles in K │ +│ │ │ │ │ +│ └────┴────┘ │ +│ │ +└─────────────────────────────────────┘ +``` + +#### Matrix B (N × K = 128 × 32): +``` +┌──────────────────────────────┐ +│ One Block Tile (128 × 32) │ +│ ┌────┬────┐ │ +│ │16×│16× │ ← Wave tiles │ +│ │ 16│ 16 │ (16×16) │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ ├────┼────┤ 8 tiles in N │ +│ │ .. │ .. │ 2 tiles in K │ +│ ├────┼────┤ │ +│ │ │ │ │ +│ └────┴────┘ │ +└──────────────────────────────┘ +``` + +#### Matrix C (M × N = 256 × 128) - Output: +``` +┌─────────────────────────────────────────────────┐ +│ One Block Tile (256 × 128) │ +│ │ +│ ┌────┬────┬────┬────┬────┬────┬────┬────┐ │ +│ │16× │ │ │ │ │ │ │ │ │ +│ │ 16 │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ .. │ │ +│ ├────┼────┼────┼────┼────┼────┼────┼────┤ │ +│ │ │ │ │ │ │ │ │ │ │ +│ └────┴────┴────┴────┴────┴────┴────┴────┘ │ +│ │ +│ 16 wave tiles in M direction │ +│ 8 wave tiles in N direction │ +│ Total: 128 wave tiles (16×16 each) │ +└─────────────────────────────────────────────────┘ +``` + +#### How Wave Tiles Combine (C = A × B): +``` +Matrix A Matrix B (stored transposed N×K) Matrix C +(256×32) (128×32) (256×128) + +Row of A tiles: Row of B tiles: One wave tile in C: +┌────┬────┐ ┌────┬────┐ ┌────┐ +│ A₀ │ A₁ │ × │ B₀ │ B₁ │ = │ C │ (16×16) +└────┴────┘ └────┴────┘ └────┘ + 16×16 each 16×16 each + +Computation: C = A₀×B₀ᵀ + A₁×B₁ᵀ + ↑ ↑ + K=0..15 K=16..31 + +Each wave tile in C is computed by: +- Taking one row of wave tiles from A (2 tiles along K) +- Taking one row of wave tiles from B (2 tiles along K) + Note: B is stored transposed (N×K), so a "row" in storage corresponds + to a "column" in the logical B^T matrix used in computation +- Performing dot product: Σ(A_k × B_k^T) for k=0,1 +``` + +**Key Insight:** +- Each **wave tile in C** (16×16) requires a **dot product** of 2 wave tiles from A and 2 wave tiles from B +- Since B is stored transposed (N×K layout), we access **rows** of B tiles in memory +- This is the fundamental operation repeated across all 128 wave tiles in C +- Each warp computes one wave tile using MFMA instructions + +--- + +## Step 7: Create Shape, Problem, and Policy Structs + +```cpp +using PracticeGemmShape = ck_tile::PracticeGemmShape; +std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; + +using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + +using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; +``` + +**What's happening:** + +### 1. **Shape Struct** +Encapsulates all tile shape information (BlockTile and WaveTile dimensions). + +### 2. **Problem Struct** +Holds complete problem description: +- Data types (ADataType, BDataType, CDataType, AccDataType) +- Shape information (BlockTile, WaveTile) + +In more complex examples, this would also include: +- Data layouts (row-major, column-major) +- Mathematical operations (e.g., transposed GEMM) + +### 3. **Policy Struct** +Describes data movement and thread-to-data mapping: +- Currently contains `MakeBlock2TileMap()`: Maps thread block IDs to tile positions +- In more complex kernels, includes: + - DRAM access patterns + - LDS (Local Data Share) usage strategies + - Thread distribution within blocks + +**CK Tile Design Pattern:** +``` +Kernel = Problem + Policy + Epilogue + ↑ ↑ ↑ + (What) (How) (Post-processing) +``` + +--- + +## Step 8: Calculate Grid and Block Dimensions + +```cpp +ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * + ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); + +std::cout << "kGridSize: " << kGridSize << std::endl; + +constexpr ck_tile::index_t kBlockSize = 256; +constexpr ck_tile::index_t kBlockPerCU = 1; +``` + +**What's happening:** + +### Grid Size Calculation +```cpp +kGridSize = ceil(M / BlockTile_M) × ceil(N / BlockTile_N) + = ceil(512 / 256) × ceil(256 / 128) + = 2 × 2 + = 4 thread blocks +``` + +Our problem requires **4 thread blocks** to cover the entire output matrix C (2 blocks in M direction, 2 blocks in N direction). + +### Block Configuration +- `kBlockSize = 256`: Each thread block has 256 threads + - 256 threads / 64 threads per warp = **4 warps per block** +- `kBlockPerCU = 1`: Launch 1 block per Compute Unit (for simplicity) + +**Thread Hierarchy:** +``` +GPU +└── 1 Thread Block (Grid) + └── 256 Threads + ├── Warp 0 (threads 0-63) + ├── Warp 1 (threads 64-127) + ├── Warp 2 (threads 128-191) + └── Warp 3 (threads 192-255) +``` + +--- + +## Step 9: Create and Launch the Kernel + +```cpp +using gemm_kernel = + ck_tile::PracticeGemmKernel; + +float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_device.GetDeviceBuffer()), + static_cast(b_device.GetDeviceBuffer()), + static_cast(c_device.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c)); +``` + +**What's happening:** + +### 1. Kernel Composition +```cpp +using gemm_kernel = ck_tile::PracticeGemmKernel; +``` +The kernel is composed from Problem and Policy structs, following the CK Tile design pattern. + +### 2. Kernel Launch +`launch_kernel()` is a CK Tile utility that: +- Launches the GPU kernel using HIP runtime +- Measures execution time +- Returns average execution time in milliseconds + +### 3. Launch Parameters +- **Stream config**: `{nullptr, true, 0, 0, 1}` - default stream, timing enabled +- **Grid size**: `kGridSize = 1` - number of thread blocks +- **Block size**: `kBlockSize = 256` - threads per block +- **Shared memory**: `0` - no dynamic shared memory in this example +- **Kernel arguments**: Device pointers and problem dimensions + +### 4. Kernel Execution Flow +``` +launch_kernel() calls gemm_kernel.operator()() + ↓ +PracticeGemmKernel::operator() + ↓ +Creates tensor views over device memory + ↓ +Calls block-level pipeline + ↓ +Block pipeline calls warp-level pipeline + ↓ +Warp pipeline calls MFMA instructions + ↓ +Results written back to C matrix +``` + +--- + +## Step 10: Verify Results + +```cpp +auto pass = true; + +if(verification) +{ + // Reference gemm on CPU + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + + // Copy GPU results back to host + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + c_device.FromDevice(c_host_dev.mData.data()); + + // Compare results + pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; +} +``` + +**What's happening:** + +### 1. CPU Reference Implementation +```cpp +reference_basic_gemm<...>(a_host, b_host, c_host_ref); +``` +Computes GEMM on CPU using a simple nested loop implementation (ground truth). + +### 2. Copy GPU Results to Host +```cpp +c_device.FromDevice(c_host_dev.mData.data()); +``` +Transfers the computed result from GPU memory back to CPU for comparison. + +### 3. Error Checking +```cpp +ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); +``` +Compares GPU and CPU results element-wise with tolerance: +- **Relative error**: 1e-3 (0.1%) +- **Absolute error**: 1e-3 + +**Verification Flow:** +``` +CPU GPU +┌─────────┐ ┌─────────┐ +│ a_host │ ────────> │a_device │ +│ b_host │ ────────> │b_device │ +└─────────┘ └─────────┘ + │ │ + ↓ ↓ +reference_gemm() GPU kernel + │ │ + ↓ ↓ +┌──────────┐ ┌──────────┐ +│c_host_ref│ │c_device │ +└──────────┘ └──────────┘ + │ │ + │ ↓ + │ FromDevice() + │ │ + ↓ ↓ + └────> check_err() <───┘ + │ + ↓ + Pass/Fail +``` + +--- + +## Complete Execution Flow Summary + +``` +1. Define data types (FP16 inputs, FP32 output) + ↓ +2. Set problem size (M=256, N=128, K=32) + ↓ +3. Create host tensors and initialize with random data + ↓ +4. Allocate device memory and transfer data (CPU → GPU) + ↓ +5. Configure hierarchical tiling (BlockTile, WaveTile) + ↓ +6. Create Shape, Problem, and Policy structs + ↓ +7. Calculate grid/block dimensions (1 block, 256 threads) + ↓ +8. Compose and launch kernel (Problem + Policy) + ↓ +9. Execute GEMM on GPU + │ ├─ Block-level pipeline + │ ├─ Warp-level pipeline + │ └─ MFMA instructions + ↓ +10. Verify results (compare GPU vs CPU reference) + ↓ +11. Calculate and print performance metrics + ↓ +12. Return success/failure +``` + +--- \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..31fa4ac3eb --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { + +template +struct PracticeGemmBlockPipelineAGmemBGmemCreg +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + static constexpr index_t MPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t NPerBlock = BlockTile::at(number<1>{}); + static constexpr index_t KPerBlock = BlockTile::at(number<2>{}); + + static constexpr index_t MPerWave = WaveTile::at(number<0>{}); + static constexpr index_t NPerWave = WaveTile::at(number<1>{}); + static constexpr index_t KPerWave = WaveTile::at(number<2>{}); + + using BlockGemm = + remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers + b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS + store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers + block_sync_lds(); + + iCounter--; + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..99c4379ad8 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp" +#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmBlockPipelineProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = Shape_; +}; + +struct PracticeGemmBlockPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline() + { + return PracticeGemmWarpPipelineASmemBSmemCreg{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kKPack = 8; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BlockGemm = remove_cvref_t())>; + constexpr index_t kMWarp = BlockGemm::MWarp; + constexpr index_t kNWarp = BlockGemm::NWarp; + constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + + constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); + constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..ef12634e42 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { +template +struct PracticeGemmHostPipeline +{ + using ADataType = typename Problem_::ADataType; + using BDataType = typename Problem_::BDataType; + using CDataType = typename Problem_::CDataType; + using AccDataType = typename Problem_::AccDataType; + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockTile = typename Problem::Shape::BlockTile; + using WaveTile = typename Problem::Shape::WaveTile; + + template + CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, + const BDRAMTensorView& b_dram, + CDRAMTensorView& c_dram_ref) const + { + + // Size of the entire problem + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + + // Size of the block tile + const auto MPerBlock = BlockTile::at(number<0>{}); + const auto NPerBlock = BlockTile::at(number<1>{}); + const auto KPerBlock = BlockTile::at(number<2>{}); + + // Number of block tile in the N direction to cover C (resultant) matrix + const auto num_tile_n = integer_divide_ceil(N, NPerBlock); + // Number of block tile in the M direction to cover C (resultant) matrix + const auto num_tile_m = integer_divide_ceil(M, MPerBlock); + + // if(get_thread_id() == 0 && get_block_id() == 0) + // { + // printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n); + // printf("total number of tiles: %d\n", num_tile_m * num_tile_n); + // } + + // Get block id + const auto id_block = + get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1 + + // Map block id to tile id + const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto tile_id = block2tile(id_block); + + const auto tile_id_m = tile_id.at(number<0>{}); + const auto tile_id_n = tile_id.at(number<1>{}); + + // if(get_thread_id() == 0 && get_block_id() == 15) + // { + // printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n); + // } + + const auto tile_origin_m = tile_id_m * MPerBlock; + const auto tile_origin_n = tile_id_n * NPerBlock; + + // create a tile window over dram for A and B + const auto a_block_window = make_tile_window( + a_dram, make_tuple(number{}, number{}), {tile_origin_m, 0}); + + const auto b_block_window = make_tile_window( + b_dram, make_tuple(number{}, number{}), {tile_origin_n, 0}); + + constexpr auto block_gemm_pipeline = + Policy::template GetPracticeGemmBlockPipeline(); + + int num_loops_k = integer_divide_ceil(K, KPerBlock); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; + const auto c_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); + auto c_window = make_tile_window(c_dram, + make_tuple(number{}, number{}), + {tile_origin_m, tile_origin_n}); + store_tile(c_window, c_block_tile); + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..d66c3c8522 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp" +#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmHostProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using Shape = remove_cvref_t; +}; + +struct PracticeGemmHostPolicy +{ + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline() + { + using PracticeGemmBlockPipelineProblem_ = + PracticeGemmBlockPipelineProblem; + return PracticeGemmBlockPipelineAGmemBGmemCreg{}; + } +}; +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp new file mode 100644 index 0000000000..ee2e125e24 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/host.hpp" +#include "practice_gemm.hpp" +#include "reference_gemm.hpp" + +int main() +{ + // TODO: GemmTypeConfig + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using CDataType = float; + using AccDataType = float; + + // ArgParser + ck_tile::index_t M = 512; + ck_tile::index_t N = 256; + ck_tile::index_t K = 64; + ck_tile::index_t verification = 1; + + ck_tile::index_t stride_a = K; + ck_tile::index_t stride_b = K; + ck_tile::index_t stride_c = N; + + auto a_lengths = std::array{M, K}; + auto b_lengths = std::array{N, K}; + auto c_lengths = std::array{M, N}; + + auto a_strides = std::array{stride_a, 1}; + auto b_strides = std::array{stride_b, 1}; + auto c_strides = std::array{stride_c, 1}; + + // tensors on host (cpu) + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host(c_lengths, c_strides); + + // initialize tensors + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); + c_host.SetZero(); + + // Print the tensors using the new print_first_n member function + // std::cout << "Tensor A (first 10 elements): "; + // a_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor B (first 10 elements): "; + // b_host.print_first_n(10); + // std::cout << std::endl; + + // std::cout << "Tensor C (first 10 elements): "; + // c_host.print_first_n(10); + // std::cout << std::endl; + + // Create device tensors of same size as host tensors and copy data + ck_tile::DeviceMem a_device(a_host); + ck_tile::DeviceMem b_device(b_host); + ck_tile::DeviceMem c_device(c_host); + + // TODO: BlockTileConfig + // constexpr ck_tile::index_t warpSize = 64; + constexpr ck_tile::index_t kBlockSize = 256; + + using BlockTile = ck_tile::sequence<256, 128, 32>; + using WaveTile = ck_tile::sequence<16, 16, 16>; + + std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl; + using PracticeGemmShape = ck_tile::PracticeGemmShape; + std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; + using PracticeGemmHostProblem = ck_tile:: + PracticeGemmHostProblem; + using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; + + ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * + ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); + + std::cout << "kGridSize: " << kGridSize << std::endl; + constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU + + std::cout << "kBlockSize: " << kBlockSize << std::endl; + std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + + using gemm_kernel = + ck_tile::PracticeGemmKernel; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_device.GetDeviceBuffer()), + static_cast(b_device.GetDeviceBuffer()), + static_cast(c_device.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c)); + + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + c_device.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp new file mode 100644 index 0000000000..88879ee221 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp" +#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp" + +namespace ck_tile { + +template +struct PracticeGemmShape +{ + using BlockTile = remove_cvref_t; + using WaveTile = remove_cvref_t; + + static constexpr index_t BlockTile_M = BlockTile::at(number<0>{}); + static constexpr index_t BlockTile_N = BlockTile::at(number<1>{}); + static constexpr index_t BlockTile_K = BlockTile::at(number<2>{}); + + static constexpr index_t WaveTile_M = WaveTile::at(number<0>{}); + static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); + static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "practice_gemm_shape", + concat('x', BlockTile_M, BlockTile_N, BlockTile_K), + concat('x', WaveTile_M, WaveTile_N, WaveTile_K)); + // clang-format on + } +}; + +template +struct PracticeGemmKernel +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + + CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, + const typename Problem::BDataType* p_b, + typename Problem::CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c) const + { + + auto a_dram = make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); + + auto b_dram = make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); + + const auto c_dram = make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); + + PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp new file mode 100644 index 0000000000..8f975be7dc --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1); +} diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..bf058af9c5 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp @@ -0,0 +1,195 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +namespace ck_tile { + +template +struct PracticeGemmWarpPipelineASmemBSmemCreg +{ + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using WaveGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == WaveGemmShape::BlockTile_M && + NPerBlock == WaveGemmShape::BlockTile_N && + KPerBlock == WaveGemmShape::BlockTile_K, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..2efa2bcc2a --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg +// Default policy class should not be templated, put template on member functions instead +struct PracticeGemmWarpPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000..9895f5a71d --- /dev/null +++ b/tutorial/ck_tile/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(00_copy_kernel) +add_subdirectory(01_naive_gemm) + From 92c1f4981ab1d081978c8f6132ca93949d4749e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 11 Nov 2025 22:55:33 +0100 Subject: [PATCH 022/118] [CK_BUILDER] Add grouped conv fwd ck tile traits (#3183) * [CK BUILDER] Add grouped conv fwd ck tile traits * Update instance_traits_tile_grouped_convolution_forward.hpp * Update grouped_convolution_forward_kernel.hpp --- .../ck_tile/builder/reflect/conv_traits.hpp | 3 + ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + ...raits_tile_grouped_convolution_forward.hpp | 140 ++++++++++++++++++ .../builder/reflect/instance_traits_util.hpp | 81 +++++++++- .../builder/test/test_fwd_instance_traits.cpp | 123 +++++++++++++++ include/ck_tile/core/arch/arch.hpp | 10 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 6 +- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 7 + .../gemm_pipeline_ag_bg_cr_comp_v6.hpp | 7 + .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 7 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 7 + .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 7 + .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 7 + .../kernel/grouped_gemm_quant_kernel.hpp | 4 +- .../grouped_convolution_forward_kernel.hpp | 17 +++ 18 files changed, 433 insertions(+), 15 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp mode change 100755 => 100644 include/ck_tile/core/arch/arch.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 86cf11f647..4b946011c2 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -15,6 +15,9 @@ #include #include #include +#include +#include "ck_tile/ops/epilogue.hpp" +#include namespace ck_tile::reflect::conv { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index a0def3e5d9..6913889c4f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -4,6 +4,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" // Forward declaration to avoid circular dependency diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp new file mode 100644 index 0000000000..f364b37ae5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp @@ -0,0 +1,140 @@ +// Copyright (C) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// InstanceTraits specialization for GroupedConvolutionForwardKernel +// +// CRITICAL MAINTENANCE NOTE: +// This InstanceTraits file MUST be kept strictly in sync with the device implementation header: +// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +// "In sync" means that the template parameter order, names, and types in the declaration below +// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter +// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are +// difficult to diagnose. Always update both files together and review changes carefully. + +#pragma once + +#include "instance_traits.hpp" +#include "instance_traits_util.hpp" + +// Forward declaration to avoid circular dependency. +namespace ck_tile::device { + +template +struct GroupedConvolutionForwardKernel; + +} // namespace ck_tile::device + +namespace ck_tile { +namespace reflect { + +// Specialization for GroupedConvolutionForwardKernel +template +struct InstanceTraits> +{ + // CK Tile Conv Traits + // Spatial dimension + static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial; + // Specialization + static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType_::ConvSpecialization; + // DataType types + using InLayout = typename GroupedConvTraitsType_::InLayout; + using WeiLayout = typename GroupedConvTraitsType_::WeiLayout; + using DsLayout = typename GroupedConvTraitsType_::DsLayout; + using OutLayout = typename GroupedConvTraitsType_::OutLayout; + // Vector size + static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA; + static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB; + static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC; + // Num Groups To Merge + static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + // Split image (large tensors) + static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + + // TilePartitioner + // Block configuration + static constexpr int kMPerBlock = TilePartitioner_::MPerBlock; + static constexpr int kNPerBlock = TilePartitioner_::NPerBlock; + static constexpr int kKPerBlock = TilePartitioner_::KPerBlock; + + static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{}); + static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{}); + + // Data types + using ADataType = typename GemmPipeline_::ADataType; + using BDataType = typename GemmPipeline_::BDataType; + // Gemm Pipeline + using GemmPipeline = GemmPipeline_; + static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler; + static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer; + static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups; + + // Epilogue Pipeline + using AccDataType = typename EpiloguePipeline_::AccDataType; + using EDataType = typename EpiloguePipeline_::ODataType; + using DsDataType = typename EpiloguePipeline_::DsDataType; + using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise; + + // Static member function to generate instance string + static std::string instance_string() + { + std::ostringstream oss; + + // Kernel type name + oss << "GroupedConvolutionForwardKernel"; + + // Template parameters in exact order matching InstanceTraits member order + oss << "<" << kSpatialDim; // 1. NDimSpatial + oss << "," + << ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization + oss << "," << detail::layout_name(); // 3. InLayout + oss << "," << detail::layout_name(); // 4. WeiLayout + oss << "," << detail::tuple_name(); // 5. DsLayout + oss << "," << detail::layout_name(); // 6. OutLayout + oss << "," << kVectorSizeA; // 7. VectorSizeA + oss << "," << kVectorSizeB; // 8. VectorSizeB + oss << "," << kVectorSizeC; // 9. VectorSizeC + oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge + oss << "," << kEnableSplitImage; // 11. EnableSplitImage + oss << "," << kMPerBlock; // 12. MPerBlock + oss << "," << kNPerBlock; // 13. NPerBlock + oss << "," << kKPerBlock; // 14. KPerBlock + oss << "," << kMWarp; // 15. MWarp + oss << "," << kNWarp; // 16. NWarp + oss << "," << kKWarp; // 17. KWarp + oss << "," << kMWarpTile; // 18. MWarpTile + oss << "," << kNWarpTile; // 19. NWarpTile + oss << "," << kKWarpTile; // 20. KWarpTile + oss << "," << detail::type_name(); // 21. ADataType + oss << "," << detail::type_name(); // 22. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups + oss << "," << kNumWaveGroups; // 26. NumWaveGroups + oss << "," << detail::type_name(); // 27. AccDataType + oss << "," << detail::type_name(); // 28. EDataType + oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," + << detail::elementwise_op_name(); // 30. + // CDEElementwiseOperation + oss << ">"; + + return oss.str(); + } +}; + +} // namespace reflect +} // namespace ck_tile diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index e4d154ae10..2e918c5c2d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -28,6 +28,10 @@ #include #include #include +#include +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" namespace ck_tile::reflect::detail { @@ -38,7 +42,7 @@ namespace impl { template consteval std::string_view type_name_impl() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v || std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "fp32"; @@ -50,11 +54,11 @@ consteval std::string_view type_name_impl() return "s8"; else if constexpr(std::is_same_v) return "s32"; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || std::is_same_v) return "bf16"; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || std::is_same_v) return "fp8"; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || std::is_same_v) return "bf8"; else return std::string_view{}; // Return empty for supported types @@ -168,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule } } +constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched) +{ + using enum ck_tile::GemmPipelineScheduler; + switch(sched) + { + case Default: return "Default"; + case Intrawave: return "Intrawave"; + case Interwave: return "Interwave"; + } +} + // Convert BlockGemmPipelineVersion enum to string constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver) { @@ -206,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched) } } +// Convert TailNumber enum to string +constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num) +{ + using enum ck_tile::TailNumber; + switch(tail_num) + { + case Odd: return "Odd"; + case Even: return "Even"; + case One: return "One"; + case Two: return "Two"; + case Three: return "Three"; + case Four: return "Four"; + case Five: return "Five"; + case Six: return "Six"; + case Seven: return "Seven"; + case Empty: return "Empty"; + case Full: return "Full"; + } +} + // Convert std::array to string template inline std::string array_to_string(const std::array& arr) @@ -356,17 +391,53 @@ constexpr std::string tuple_name() }(static_cast(nullptr)); } +template + requires requires { [](ck_tile::tuple*) {}(static_cast(nullptr)); } +constexpr std::string tuple_name() +{ + return [](ck_tile::tuple*) constexpr { + if constexpr(sizeof...(Ts) == 0) + { + return std::string("EmptyTuple"); + } + else if constexpr((IsLayoutType && ...)) + { + // Lambda wrapper for layout_name + auto layout_name_fn = []() { return layout_name(); }; + return detail::build_list_string("tuple", + layout_name_fn); + } + else if constexpr((IsDataType && ...)) + { + // Lambda wrapper for type_name + auto type_name_fn = []() { return type_name(); }; + return detail::build_list_string("tuple", type_name_fn); + } + else + { + static_assert((IsLayoutType && ...) || (IsDataType && ...), + "tuple elements must be all layouts or all data types, not mixed"); + return std::string{}; // unreachable + } + }(static_cast(nullptr)); +} + // Concept to check if a type is a ck::Tuple template concept IsCkTuple = requires { [](ck::Tuple*) {}(static_cast(nullptr)); }; +// Concept to check if a type is a ck_tile::tuple +template +concept IsCkTileTuple = + requires { [](ck_tile::tuple*) {}(static_cast(nullptr)); }; + // Deduces whether to use tuple_name or type_name // Handles both scalar data types and ck::Tuple types template constexpr std::string type_or_type_tuple_name() { - if constexpr(IsCkTuple) + if constexpr(IsCkTuple || IsCkTileTuple) { return tuple_name(); } diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index b57b20eb7d..af950b441c 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace { @@ -720,4 +721,126 @@ TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat) EXPECT_EQ(instance_str, expected_str); } +TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) +{ + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits<2 /*NDimSpatial*/, + ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/, + ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/, + ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/, + ck_tile::tuple<> /*DsLayout*/, + ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/, + 4 /*VectorSizeA*/, + 4 /*VectorSizeB*/, + 4 /*VectorSizeC*/, + 1 /*NumGroupsToMerge*/, + false /*EnableSplitImage*/>; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, + ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>, + ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + false /*DoubleSmemBuffer*/, + typename GroupedConvTraitsType::AsLayoutFwd, + typename GroupedConvTraitsType::BsLayoutFwd, + typename GroupedConvTraitsType::CLayoutFwd, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + 1 /*NumWaveGroups*/>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ck_tile::bf16_t /*InDataType*/, + ck_tile::bf16_t /*WeiDataType*/, + float /*AccDataType*/, + GemmShape, + GemmUniversalTraits, + ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/, + true /*has_hot_loop_v*/, + ck_tile::TailNumber::Full /*tail_number_v*/, + ck_tile::element_wise::PassThrough /*AElementwiseOperation*/, + ck_tile::element_wise::PassThrough /*BElementwiseOperation*/, + ck_tile::bf16_t /*OutDataType*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem /*DsDataType*/, + float /*AccDataType*/, + ck_tile::bf16_t /*OutDataType*/, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + ck_tile::element_wise::PassThrough /*CDElementWise*/, + 128 /*MPerBlock*/, + 128 /*NPerBlock*/, + 4 /*M_Warp*/, + 1 /*N_Warp*/, + 16 /*M_Warp_Tile*/, + 16 /*N_Warp_Tile*/, + 16 /*K_Warp_Tile*/, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + ck_tile::memory_operation_enum::set /*memory_operation*/, + 1 /*kNumWaveGroups*/, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using GroupedConvFwdKernel = + ck_tile::device::GroupedConvolutionForwardKernel; + + std::string instance_str = ck_tile::reflect::instance_string(); + + std::string expected_str = "GroupedConvolutionForwardKernel" + "<2" // NDimSpatial + ",Default" // ConvSpecialization + ",NHWGC" // InLayout + ",GKYXC" // WeiLayout + ",EmptyTuple" // DsLayout + ",NHWGK" // OutLayout + ",4" // VectorSizeA + ",4" // VectorSizeB + ",4" // VectorSizeC + ",1" // NumGroupsToMerge + ",0" // EnableSplitImage + ",128" // MPerBlock + ",128" // NPerBlock + ",32" // KPerBlock + ",4" // MWarp + ",1" // NWarp + ",1" // KWarp + ",16" // MWarpTile + ",16" // NWarpTile + ",16" // KWarpTile + ",bf16" // ADataType + ",bf16" // BDataType + ",COMPUTE_V3" // BlkGemmPipelineVer + ",Intrawave" // BlkGemmPipeSched + ",0" // DoubleSmemBuffer + ",1" // NumWaveGroups + ",fp32" // AccDataType + ",bf16" // EDataType + ",EmptyTuple" // DsDataType + ",PassThrough" // CDEElementwiseOperation + ">"; + + EXPECT_EQ(instance_str, expected_str); +} + } // anonymous namespace diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp old mode 100755 new mode 100644 index 5bf8548470..b66c00e392 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -299,12 +299,12 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) #endif } -#define CK_CONSTANT_ADDRESS_SPACE \ - __attribute__((address_space( \ +#define CK_TILE_CONSTANT_ADDRESS_SPACE \ + __attribute__((address_space( \ static_cast>(address_space_enum::constant)))) template -__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) +__device__ T* cast_pointer_to_generic_address_space(T CK_TILE_CONSTANT_ADDRESS_SPACE* p) { // cast a pointer in "Constant" address space (4) to "Generic" address space (0) // only c-style pointer cast seems be able to be compiled @@ -315,13 +315,13 @@ __device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* } template -__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) +__host__ __device__ T CK_TILE_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) { // cast a pointer in "Generic" address space (0) to "Constant" address space (4) // only c-style pointer cast seems be able to be compiled; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" - return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) + return (T CK_TILE_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) #pragma clang diagnostic pop } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 551dc6f50d..a72b1ba544 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -190,7 +190,7 @@ struct GroupedGemmKernel */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*; const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( @@ -518,7 +518,7 @@ struct GroupedGemmKernel // For non-persistent kernels template > - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, index_t group_count) const { const index_t block_id = ck_tile::get_block_1d_id(); @@ -541,7 +541,7 @@ struct GroupedGemmKernel template , typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count) const { const index_t grid_size = ck_tile::get_grid_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 91da3cd27b..b293097d89 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -164,6 +164,13 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_ASYNC"; + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index aaa04615fd..a1bbcbe990 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -170,6 +170,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base::PrefetchStages; using Base::UsePersistentKernel; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V3"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index ff1e33bd5d..238b4e2389 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -172,6 +172,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V4"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 7263ddd5a1..6343ff9872 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -99,6 +99,13 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr index_t NumWarps = BlockGemmShape::NumWarps; static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V5"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp index 2ae9001098..5b57560f6e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -159,6 +159,13 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_V6"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index d363626efd..ba71e3b6cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -214,6 +214,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "MEMORY"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index eb363d59b8..8a4fb59b51 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t kLdsAlignmentInBytes = 16; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "BASIC_V1"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index c309f8908a..32217e0024 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -70,6 +70,13 @@ struct GemmPipelineAGmemBGmemCRegV2 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. static constexpr bool DoubleSmemBuffer = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "BASIC_V2"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 87f6c753b4..cae2bd0e9f 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -176,6 +176,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "PRESHUFFLE_V2"; + // clang-format on + } + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 75ac1ca6ab..32f1279e93 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -208,7 +208,7 @@ struct QuantGroupedGemmKernel */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*; + using ConstantPointer = const void CK_TILE_CONSTANT_ADDRESS_SPACE*; const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>; int occupancy; HIP_CHECK_ERROR( @@ -499,7 +499,7 @@ struct QuantGroupedGemmKernel template , typename = void> // extra template parameter to avoid redefinition - CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count) const { const index_t grid_size = ck_tile::get_grid_size(); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 7e70d2b422..6de331fe6d 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -16,6 +16,10 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" +#ifdef CK_EXPERIMENTAL_BUILDER +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp" +#endif + namespace ck_tile { /// @brief The Grouped Convolution kernel device arguments. @@ -568,6 +572,19 @@ struct GroupedConvolutionForwardKernel // clang-format on } +#ifdef CK_EXPERIMENTAL_BUILDER + CK_TILE_HOST std::string GetInstanceString() const + { + static_assert(ck_tile::reflect::HasInstanceTraits, + "Specialization of instance_traits not found. Please check that a " + "specialization exists in file " + "ck_tile/builder/reflect/" + "instance_traits_tile_grouped_convolution_forward.hpp " + "for the given template parameters."); + return ck_tile::reflect::instance_string(); + } +#endif + CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { return dim3( From 40d2ed0f2a442026c57dc17e6e7bd281b6c2535c Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 12 Nov 2025 10:26:14 +0800 Subject: [PATCH 023/118] [CK_TILE] Share partition index across threads and specify offset in load_tile()/async_load_tile()/load_tile_transpose() (#2905) * Allow sharing partition index across threads * Fix typo PartitoinIndex -> PartitionIndex * Remove C++20 'requires' usages * Add missing template arguments * Fix load_tile() overload ambiguity issue * Use SFINAE to exclude invalid arguments * Add additional offset parameter to the async_load_tile() * Remove async_load_tile() default argument to avoid ambiguity * Extract tile_window coordinate compute logic as method * Use warp-shared LDS base address in tile_window::async_load() * Add constraint to tile_window::load() templates * Fix wrong type traits is_class_v<> usages * Add missing constraint to async_load_tile() * Add missing tile_window::load() overload * Add more constraint to avoid load_tile() call ambiguity * Rename ParitionIndex as ReplacementPartitionIndex * Update pre_computed_warp_coords_ in move_extended() * Fix inconsistency between template parameters and documentation * Allow specifying pre-computed parition index * Add type straits is_sequence<> & is_tile_distribution<> * Add type straits is_tensor_view<> * Add type constraints to make_tile_window() templates * Allow passing partition_index to set_tile_if() * Allow specifying partition_index to store_tile() * Add missing template parameter of replace_bottom_tensor_view() * Allow passing partition_index to Default2DEpilogue * Make get_partition_index() public * Add _with_offset() postfix to avoid resolution error * Remove ReplacementPartitionIndex template param * Add missing comments * Add load_tile_transpose_with_offset() overload --- include/ck_tile/core/container/sequence.hpp | 11 + include/ck_tile/core/tensor/load_tile.hpp | 51 ++++- .../core/tensor/load_tile_transpose.hpp | 60 +++++- .../core/tensor/static_distributed_tensor.hpp | 41 +++- include/ck_tile/core/tensor/store_tile.hpp | 51 +++++ include/ck_tile/core/tensor/tensor_view.hpp | 15 ++ .../ck_tile/core/tensor/tile_distribution.hpp | 27 ++- .../core/tensor/tile_scatter_gather.hpp | 5 +- include/ck_tile/core/tensor/tile_window.hpp | 195 +++++++++++++++--- .../ops/epilogue/default_2d_epilogue.hpp | 41 +++- .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- 11 files changed, 441 insertions(+), 58 deletions(-) diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index cfec2237f9..1a88a98cbf 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -214,6 +214,17 @@ CK_TILE_HOST_DEVICE static void print(const sequence&) printf(">"); } +template +struct is_sequence : std::false_type +{ +}; +template +struct is_sequence> : std::true_type +{ +}; +template +inline constexpr bool is_sequence_v = is_sequence::value; + namespace impl { template struct __integer_sequence; diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 2e9ab0f5c6..1be4259e97 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -17,6 +17,19 @@ #include "ck_tile/core/tensor/null_tensor.hpp" namespace ck_tile { +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template >> +CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, number{}, bool_constant{}); +} template CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, @@ -49,6 +62,23 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, tile_window, elementwise, number{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load_with_offset( + offset, dst_tile, number{}, bool_constant{}); +} + template {}, bool_constant{}, bool_constant{}); } +// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. +template > && + std::is_class_v>> +CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.async_load_with_offset( + offset, lds_tile, number{}, bool_constant{}); +} + template = {}, bool_constant = {}) { - return tile_window.async_load( - lds_tile, number{}, bool_constant{}); + return async_load_tile_with_offset( + lds_tile, tile_window, 0, number{}, bool_constant{}); } template ::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto -load_tile_transpose(const tile_window_with_static_distribution& tile_window) +CK_TILE_DEVICE auto load_tile_transpose_with_offset( + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) { using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose(); + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); @@ -443,4 +446,49 @@ load_tile_transpose(const tile_window_with_static_distribution, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE auto +load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) +{ + return load_tile_transpose_with_offset(tile_window, 0); +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index b73a27c8d5..5228ad978a 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -155,11 +155,11 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTi // get X indices from tuple of tile_distributed_index<> template -CK_TILE_HOST_DEVICE constexpr auto -get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, - DistributedIndices distributed_indices) +CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices( + StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices, + decltype(get_partition_index(tile_distribution)) partition_index) { - const auto partition_index = detail::get_partition_index(tile_distribution); constexpr auto y_indices = tile_distribution.get_y_indices_from_distributed_indices(distributed_indices); @@ -170,6 +170,16 @@ get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, return x_coord.get_bottom_index(); } +// get X indices from tuple of tile_distributed_index<> +template +CK_TILE_HOST_DEVICE constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + return get_x_indices_from_distributed_indices( + tile_distribution, distributed_indices, get_partition_index(tile_distribution)); +} + template CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor& out_tensor, @@ -192,6 +202,29 @@ set_tile_if(static_distributed_tensor& out_ten }); } +template +CK_TILE_HOST_DEVICE void +set_tile_if(static_distributed_tensor& out_tensor, + DataType value, + XIndicesPredicate predicate, + decltype(get_partition_index(std::declval())) partition_index) +{ + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + StaticTileDistribution{}, distributed_indices, partition_index); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + // this function used inside span loop over template CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number) diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d..b535b40534 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -38,6 +39,31 @@ store_tile(tile_window_with_static_lengths& t tile_window.store(dstr_tensor); } +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store(dstr_tensor); +} + template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor, + decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr, + partition_index); + + tile_window.store_raw(dstr_tensor); +} + template +struct is_tensor_view : std::false_type +{ +}; +template +struct is_tensor_view> : std::true_type +{ +}; +template <> +struct is_tensor_view : std::true_type +{ +}; +template +inline constexpr bool is_tensor_view_v = is_tensor_view::value; + template CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) { - return Distribution::_get_partition_index(); + return Distribution::get_partition_index(); } -} // namespace detail // distributed span template @@ -91,7 +89,7 @@ struct tile_distribution CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } - CK_TILE_HOST_DEVICE static auto _get_partition_index() + CK_TILE_HOST_DEVICE static auto get_partition_index() { // only support warp-tile and block-tile static_assert(NDimP == 1 or NDimP == 2, "wrong!"); @@ -172,9 +170,9 @@ struct tile_distribution } #endif - template + template CK_TILE_HOST_DEVICE auto - calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const + calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const { const auto ps_ys_idx = container_concat(ps_idx, array{0}); const auto window_adaptor_thread_coord_tmp = @@ -230,6 +228,23 @@ struct tile_distribution } }; +template +struct is_tile_distribution : std::false_type +{ +}; +template +struct is_tile_distribution> : std::true_type +{ +}; +template +inline constexpr bool is_tile_distribution_v = is_tile_distribution::value; + namespace detail { template diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 4b04fd513d..e77ca805bb 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -189,8 +189,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(get_partition_index(tile_distribution), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = @@ -836,7 +835,7 @@ struct tile_scatter_gather // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); + container_concat(get_partition_index(tile_dstr_), array{0})); #endif BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index cfa2420f2f..1123ce7604 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" @@ -67,18 +68,54 @@ struct tile_window_with_static_distribution const typename Base::BottomTensorView& bottom_tensor_view, const typename Base::WindowLengths& window_lengths, const typename Base::BottomTensorIndex& window_origin, - const typename Base::TileDstr& tile_distribution) + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) : pre_computed_coords_{} { - this->window_origin_ = window_origin; - this->window_lengths_ = window_lengths; - this->bottom_tensor_view_ = bottom_tensor_view; - this->tile_dstr_ = tile_distribution; + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; + + pre_computed_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index); + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + auto use_lane_id_0 = partition_index; + use_lane_id_0[1] = 0; + + pre_computed_warp_coords_ = + prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0); + } + } + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution( + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : tile_window_with_static_distribution(bottom_tensor_view, + window_lengths, + window_origin, + tile_distribution, + get_partition_index(tile_distribution)) + { + } + + CK_TILE_DEVICE constexpr auto + prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) const + { + array, NumCoord> + coords; + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat(partition_index, multi_index{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -105,18 +142,31 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - pre_computed_coords_(iCoord) = - make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + + return coords; } template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + return load_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, number{}, bool_constant{}); + load_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -236,6 +286,19 @@ struct tile_window_with_static_distribution CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const + { + load_with_offset( + 0, dst_tensor, number{}, bool_constant{}); + } + + template >>> + CK_TILE_DEVICE auto load_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -258,7 +321,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, 0, bool_constant{}); + bottom_tensor_thread_coord, offset, bool_constant{}); // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -450,10 +513,12 @@ struct tile_window_with_static_distribution template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - number = {}, - bool_constant = {}) const + bool oob_conditional_check = true, + typename = std::enable_if_t>>> + CK_TILE_DEVICE auto async_load_with_offset(index_t offset, + LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; @@ -472,12 +537,15 @@ struct tile_window_with_static_distribution auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; // Use precomputed window origin auto lds_bottom_tensor_thread_idx = - window_origin + window_adaptor_thread_coord.get_bottom_index(); + window_origin + window_adaptor_warp_coord.get_bottom_index(); // Use precomputed tensor descriptor const auto lds_coord = @@ -490,7 +558,7 @@ struct tile_window_with_static_distribution this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, - number<0>{}, + offset, bool_constant{}); // Move thread coordinate if not last access @@ -503,18 +571,33 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); } }); }); } template - CK_TILE_DEVICE auto load_transpose() const + CK_TILE_DEVICE auto load_transpose(number = {}, + bool_constant = {}) const + { + return this->template load_transpose_with_offset( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - this->template load_transpose( - dst_tensor, number{}, bool_constant{}); + this->template load_transpose_with_offset(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -522,9 +605,10 @@ struct tile_window_with_static_distribution typename DistributedTensor, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true> - CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -550,7 +634,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = this->get_bottom_tensor_view() .template get_transpose_vectorized_elements( - bottom_tensor_thread_coord, 0); + bottom_tensor_thread_coord, offset); // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( @@ -862,16 +946,26 @@ struct tile_window_with_static_distribution pre_computed_coords_(iCoord)(I1), step); }); + + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step); + }); + } } CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { // TODO: this use less register for FA, but more register for GEMM // need investigation - const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - this->tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(this->tile_dstr_), - array{0})); + const auto window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(get_partition_index(this->tile_dstr_), + array{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -908,6 +1002,12 @@ struct tile_window_with_static_distribution // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + // pre_computed_warp_coords_ exists only in the global memory tile_window + std::conditional_t< + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global, + array, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy @@ -929,6 +1029,27 @@ make_tile_window(const TensorView_& tensor_view, tensor_view, window_lengths, origin, tile_distribution}; } +template && + is_tile_distribution_v>> +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index, + number = {}) +{ + return tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, partition_index}; +} + // this version can't be called in a constexpr context template +CK_TILE_DEVICE constexpr auto +make_tile_window(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + decltype(get_partition_index(tile_distribution)) partition_index) +{ + return make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + partition_index); +} + template CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window.get_bottom_tensor_view(), - tile_window.get_window_lengths(), - tile_window.get_window_origin(), - tile_distribution); + auto w = make_tile_window(tile_window, tile_distribution); w.init_raw(); return w; } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 2843966cd7..8cf47c46e7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -93,13 +93,27 @@ struct Default2DEpilogue const DsDramWindows& ds_dram_windows, void* = nullptr) const { + constexpr bool is_partition_index = + std::is_convertible_v; + const auto storeOrUpdateTile = [&](const auto& o_tile) { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile_raw(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_tile)); + } } else { @@ -111,16 +125,35 @@ struct Default2DEpilogue { if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + store_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } else { - update_tile(o_dram_window_tmp, cast_tile(o_tile)); + if constexpr(is_partition_index) + { + update_tile(o_dram_window_tmp, + cast_tile(o_tile), + /*partition_index=*/ds_dram_windows); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_tile)); + } } } }; - if constexpr(!std::is_same_v && Problem::NumDTensor >= 1) + if constexpr(!std::is_same_v && !is_partition_index && + Problem::NumDTensor >= 1) { using elementwise_result_t = decltype(load_tile( make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(), diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 7a10d1fa56..2fd8a48eee 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -32,7 +32,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, constexpr index_t idim_p_lane = NDimP - 1; - const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution()); + const auto ps_idx = get_partition_index(acc_tensor.get_tile_distribution()); const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size(); From 299c9bca1bee2ef77bb78878bcdd9d11a13564e5 Mon Sep 17 00:00:00 2001 From: Yashvardhan Agarwal Date: Wed, 12 Nov 2025 17:30:20 +0200 Subject: [PATCH 024/118] [CK_Tile] Pooling example readme update (#3174) * pooling example readme update - The updated readme explains the transformations of the pooling kernel using a mermaid diagram * Update example/ck_tile/36_pooling/README.md Co-authored-by: spolifroni-amd * resolve comments --------- Co-authored-by: spolifroni-amd --- example/ck_tile/36_pooling/README.md | 110 +++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/example/ck_tile/36_pooling/README.md b/example/ck_tile/36_pooling/README.md index ab49b57095..4417e03734 100644 --- a/example/ck_tile/36_pooling/README.md +++ b/example/ck_tile/36_pooling/README.md @@ -2,6 +2,116 @@ This folder contains example for the pooling operator using ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling. +## Tensor Descriptor Transformations + +The pooling kernel transforms the input tensor into 2D format suitable for reduction. This section explains the transformation pipeline for both 2D and 3D pooling operations. + +### 3D Pooling Transformations + +For 3D pooling, the input tensor has shape `(N, D, H, W, C)` where: +- `N`: batch size +- `D`: depth dimension +- `H`: height dimension +- `W`: width dimension +- `C`: channel dimension + +The transformations convert this 5D tensor into a 2D tensor where rows represent output positions (M) and columns represent pooling window elements (K). + +```mermaid +graph TD + %% Input Tensor: (N, D, H, W, C) + Input["Input Tensor
(N, D, H, W, C)"] + style Input fill:#e1f5fe + + %% Pass-through N dimension + PassN["Pass-through N
(batch size)"] + style PassN fill:#f3e5f5 + Input --> PassN + + %% Pad spatial dimensions + PadD["Pad D
(depth with left/right padding)"] + style PadD fill:#fff9c4 + Input --> PadD + + PadH["Pad H
(height with left/right padding)"] + style PadH fill:#fff9c4 + Input --> PadH + + PadW["Pad W
(width with left/right padding)"] + style PadW fill:#fff9c4 + Input --> PadW + + %% Pass-through C dimension + PassC["Pass-through C
(channels)"] + style PassC fill:#f3e5f5 + Input --> PassC + + %% Embed sliding windows + EmbedD["Embed D
window(Z) × output_positions(Dₒ)"] + style EmbedD fill:#fff3e0 + PadD --> EmbedD + + EmbedH["Embed H
window(Y) × output_positions(Hₒ)"] + style EmbedH fill:#fff3e0 + PadH --> EmbedH + + EmbedW["Embed W
window(X) × output_positions(Wₒ)"] + style EmbedW fill:#fff3e0 + PadW --> EmbedW + + %% Merge into 2D matrix + MergeM["Merge M
(N, Dₒ, Hₒ, Wₒ, C)
→ output positions"] + style MergeM fill:#e8f5e9 + PassN --> MergeM + EmbedD --> MergeM + EmbedH --> MergeM + EmbedW --> MergeM + PassC --> MergeM + + MergeK["Merge K
(Z, Y, X)
→ window elements"] + style MergeK fill:#e8f5e9 + EmbedD --> MergeK + EmbedH --> MergeK + EmbedW --> MergeK + + %% Final padding for block alignment + PadM["Right-pad M
(for block alignment)"] + style PadM fill:#fff9c4 + MergeM --> PadM + + PadK["Right-pad K
(for block alignment)"] + style PadK fill:#fff9c4 + MergeK --> PadK + + %% Result + Result["2D Matrix
(M × K)"] + style Result fill:#c8e6c9 + PadM --> Result + PadK --> Result +``` + +**Transformation Steps:** +1. **Padding**: Apply left and right padding to spatial dimensions (D, H, W) to handle boundary conditions +2. **Sliding Windows**: Use embed transforms to create sliding windows across each spatial dimension, expanding each dimension into (window_size, output_positions) +3. **Reshaping**: Merge all dimensions into a 2D matrix where: + - M dimension = N × Dₒ × Hₒ × Wₒ × C (total output positions) + - K dimension = Z × Y × X (elements per pooling window) +4. **Block Alignment**: Apply right padding to ensure M and K dimensions are aligned to block size + +### 2D Pooling Transformations + +2D pooling follows the same transformation pipeline but operates on 4D tensors with shape `(N, H, W, C)`. The process is identical except: +- Only H and W dimensions are padded and embedded +- K dimension merges only (Y, X) window elements +- M dimension merges (N, Hₒ, Wₒ, C) + +### Output Tensor Transformations + +The output tensor transformations are simpler: +- Merge all output dimensions (N, Dₒ/Hₒ, Wₒ, C) into a single M dimension +- Apply right padding for block alignment +- The result is a 1D tensor that maps directly to the M dimension of the computation matrix + ## build ``` # in the root of ck_tile From 7414a0f4d43fbb581421c236c7c68bf2ba7664ca Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Wed, 12 Nov 2025 20:23:54 +0100 Subject: [PATCH 025/118] Wmma support for gemm_reduce (#3145) * Initial implementation GEMM+Reduce: - device struct - epilogue struct * Fix tests, improve profiler and add initial instances * Add instances * Fix compilation error * Address review comments * Fix logging --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 661 ++++++++++++++++++ .../grid/epilogue_cshuffle_v3_reduce_wmma.hpp | 470 +++++++++++++ .../grid/epilogue_cshuffle_v3_wmma_base.hpp | 1 + .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 24 + .../gpu/gemm_reduce/CMakeLists.txt | 7 +- ..._f16_f16_f16_f32_f32_km_kn_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_km_nk_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp | 88 +++ ..._f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp | 86 +++ .../profiler/profile_gemm_reduce_impl.hpp | 55 +- test/gemm_reduce/CMakeLists.txt | 10 +- ...duce_fp16_xdl.cpp => gemm_reduce_fp16.cpp} | 2 +- 12 files changed, 1568 insertions(+), 12 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp rename test/gemm_reduce/{gemm_reduce_fp16_xdl.cpp => gemm_reduce_fp16.cpp} (96%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..166c1a7581 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,661 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = + EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + + using CDEShuffleBlockTransferScalarPerVectors = + Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using ReduceTrait = ReduceTrait_; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_c_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Even) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else if(TailNum == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_reduce_wmma_cshuffle_v3; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{}, + static_cast(arg.p_c_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp new file mode 100644 index 0000000000..c2bd65f134 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" + +namespace ck { + +template +struct ReduceTrait_ +{ + using ReduceAccDataType_ = ReduceAccDataType; + using ReducePtrsGlobal_ = ReducePtrsGlobal; + using ReduceOperations_ = ReduceOperations; + using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations; + using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations; + using ReduceGlobalMemoryDataOperation_ = ReduceGlobalMemoryDataOperation; + using CReduceThreadClusterLengths_MPerBlock_NPerBlock_ = + CReduceThreadClusterLengths_MPerBlock_NPerBlock; + static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_ = + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock; + static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_ = + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock; +}; + +template +struct EpilogueReduceCShuffle + : EpilogueCShuffleBase +{ + using Base = EpilogueCShuffleBase< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe>; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + using Base::GetCShuffleLDSDescriptor; + using Base::GetVgprToLDSEpilogueDescriptor; + using Base::I0; + using Base::I1; + using Base::I3; + using Base::NumDTensor; + + // assume Reduce is packed tensor + __device__ static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + __device__ EpilogueReduceCShuffle( + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_, + const index_t MRaw_) + : p_reduces_grid(p_reduces_grid_), + reduce_in_element_ops(reduce_in_element_ops_), + reduce_out_element_ops(reduce_out_element_ops_), + MRaw(MRaw_), + reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)} + { + } + + template + __device__ void Run(CThreadBuf& c_thread_buf, + DsGridPointer p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + auto reduce_grid_desc_mblock_mperblock = + MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // LDS buffer + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + // Thread transfer Vgpr to LDS + auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor(); + + // Space Filling Curve Vgpr + constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{}; + + // Space Filling Curve Vmem + constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{}; + + // Block descriptor + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + GetCShuffleLDSDescriptor(); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // Thread transfer LDS to Vmem + auto cde_shuffle_block_copy_lds_and_global = + Base::template GetLDSToVmemEpilogueDescriptor( + c_ds_desc_refs, + e_grid_desc_mblock_mperblock_nblock_nperblock, + cde_element_op, + block_m_id, + block_n_id); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength( + I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert( + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) * + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + BlockSize, + "wrong!"); + + static_assert( + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) == + 0 && + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) % + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) / + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) / + ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1); + + static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size(); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = + make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{}, + Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + CShuffleDataType, + typename ReduceTrait::ReduceAccDataType_, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_, + ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_m_id, // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + // CShuffle and Store + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block loads its C data from LDS, D from global, applies elementwise + // operation and stores result E to global + cde_shuffle_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, NumReduce, 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = + remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = ReduceOperation::template GetIdentityValue< + typename ReduceTrait::ReduceAccDataType_>(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + } + }); + } + + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid; + typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops; + typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops; + index_t MRaw; + ReduceGridDesc_M reduce_grid_desc_m; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index d2c6c92c9f..30f81b7411 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 56f09cee96..020d0110cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -25,6 +25,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" namespace ck { @@ -622,6 +623,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BlockwiseGemmPipe, BlockSize>; + template + using EpilogueReduceCShuffle = EpilogueReduceCShuffle< + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + MPerBlock, + NPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + CDEElementwiseOperation, + ThisThreadBlock, + BlockwiseGemmPipe, + GemmSpec, + BlockSize, + ReduceTrait>; + template __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 7ee3efe7f5..12d1026ea1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,7 +1,12 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp + + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..d92e84380f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..b21531e394 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d32e663b1c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..f4013b5414 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using ReducePtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceInElementOps = ck::Tuple; +using ReduceOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout|AData| BData| EData| Acc| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| Reduce| ReduceIn| ReduceAcc| ReduceGlobal| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEShuffleBlockTransferClusterLengths| CDEShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm| + //##############################| | | | Type| Type| Type| DataType| DataType| DataType| | Elementwise| Elementwise| Elementwise| Operation| Elementwise| Elementwise| MemoryData| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Operations| Operations| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // v1 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>, + // v1 Interwave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Interwave, BlockGemmPipelineVersion::v1>, + // v3 Intrawave + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index 74a1b60fe3..c870a95cbe 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple; using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>; +#ifdef CK_USE_XDL void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); @@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( std::vector&); +#endif +#ifdef CK_USE_WMMA +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); +#endif } // namespace instance } // namespace device @@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification, is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); +#endif } else if constexpr(is_same::value && is_same::value && is_same::value) { +#ifdef CK_USE_XDL ck::tensor_operation::device::instance:: add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( gemm_ptrs); +#endif +#ifdef CK_USE_WMMA + ck::tensor_operation::device::instance:: + add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); +#endif } } @@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + std::string gemm_name = gemm_ptr->GetTypeString(); + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { ++num_kernel; @@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification, float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - std::string gemm_name = gemm_ptr->GetTypeString(); - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + @@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data()); - ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); - ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); - ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result); + pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result); if(do_log) { @@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << gemm_name << ": does not support this GEMM problem" << std::endl; } } diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 121ecde609..ae2246e628 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,6 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) -endif() \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) + endif() +endif() diff --git a/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp similarity index 96% rename from test/gemm_reduce/gemm_reduce_fp16_xdl.cpp rename to test/gemm_reduce/gemm_reduce_fp16.cpp index b1f2c36c9f..30657c87c5 100644 --- a/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include From 3784c0e7c395af214fdddd5f702691b354bfe8d4 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 12 Nov 2025 11:47:07 -0800 Subject: [PATCH 026/118] add permissions for /tmp folder (#3201) --- Dockerfile.aiter | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile.aiter b/Dockerfile.aiter index dab3f9588d..94591f9012 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -17,4 +17,6 @@ RUN pip install pandas zmq einops ninja && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ chown -R jenkins:jenkins /home/jenkins && \ chmod -R a+rwx /home/jenkins && \ + chown -R jenkins:jenkins /tmp && \ + chmod -R a+rwx /tmp && \ sudo usermod -aG irc jenkins From 9342365713f6c8601e35921e7adeba9769b784b7 Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Wed, 12 Nov 2025 17:05:53 -0700 Subject: [PATCH 027/118] Add C++17 deprecation warning to CHANGELOG.md (#3203) * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md --- CHANGELOG.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 213631721f..44d0837b40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## (Unreleased) Composable Kernel for ROCm - -### Added +## Composable Kernel 1.1.0 for ROCm 7.2.0 ### Added * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM @@ -32,6 +30,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures. * FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time. +### Upcoming changes + +* To enhance capabilities and user experience, Composable Kernel will adopt C++20 features in ROCm 8.0, updating the minimum compiler requirement to C++20. Please ensure your development environment meets this requirement for a seamless transition. + ## Composable Kernel 1.1.0 for ROCm 7.1.0 ### Added From 797ddfa41e5e2c45f9eea9e6c969ba528e5a9c39 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 12 Nov 2025 19:07:28 -0500 Subject: [PATCH 028/118] chore(copyright): update copyright header for test_data directory (#3194) * chore(copyright): update copyright header for tile_engine directory * chore(copyright): update copyright header for script directory * chore(copyright): update copyright header for test_data directory --- test_data/generate_model_configs.py | 530 ++++++++++++++-------------- test_data/generate_test_dataset.sh | 2 +- test_data/miopen_to_csv.py | 2 +- test_data/run_model_with_miopen.py | 2 +- 4 files changed, 268 insertions(+), 268 deletions(-) diff --git a/test_data/generate_model_configs.py b/test_data/generate_model_configs.py index 567870fd73..f3c47e3715 100644 --- a/test_data/generate_model_configs.py +++ b/test_data/generate_model_configs.py @@ -1,265 +1,265 @@ -#!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Generate Model Configuration Combinations for MIOpen Testing - -This script generates all possible combinations of model parameters -and saves them as CSV files that can be read by the shell script. -""" - -import csv -import argparse - - -def generate_2d_configs(mode="full"): - """Generate all 2D model configuration combinations - - Args: - mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs) - """ - - # Define parameter ranges - models_2d = [ - "resnet18", - "resnet34", - "resnet50", - "mobilenet_v2", - "mobilenet_v3_large", - "mobilenet_v3_small", - "vgg11", - "vgg16", - "vgg19", - "alexnet", - "googlenet", - "densenet121", - "densenet161", - "squeezenet1_0", - "squeezenet1_1", - "shufflenet_v2_x1_0", - ] - - if mode == "small": - # Minimal set for quick testing - batch_sizes = [1, 8] # Just two batch sizes - # Very limited input dimensions - only 2 key sizes - input_dims = [ - (224, 224), # Standard (most common) - (256, 256), # Medium - ] - # Use only first 3 models for minimal testing - models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50 - elif mode == "half": - # Reduced set for faster testing - batch_sizes = [1, 8, 32] # Small, medium, large - # Reduced input dimensions - 5 key sizes - input_dims = [ - (64, 64), # Small - (224, 224), # Standard (most common) - (512, 512), # Large - (224, 320), # Rectangular - (227, 227), # AlexNet preferred - ] - else: # full mode - # More comprehensive but still limited - batch_sizes = [1, 4, 8, 16, 32] - # More dimensions but skip some redundant ones - input_dims = [ - (64, 64), - (128, 128), - (224, 224), - (256, 256), - (512, 512), # Square - (224, 320), - (320, 224), # Rectangular (reduced from 4) - (227, 227), # AlexNet preferred - (299, 299), # Inception preferred - ] - - precisions = ["fp32"] # , 'fp16', 'bf16'] - channels = [3] # Most models expect RGB - - configs = [] - config_id = 1 - - # Generate all combinations (but limit to reasonable subset) - for model in models_2d: - for batch_size in batch_sizes: - for height, width in input_dims: - for precision in precisions: - # Skip some combinations to keep dataset manageable - if batch_size > 16 and height > 256: - continue # Skip large batch + large image combinations - if precision != "fp32" and batch_size < 8: - continue # Skip mixed precision with tiny batches - - config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}" - - config = { - "config_name": config_name, - "model": model, - "batch_size": batch_size, - "channels": channels[0], - "height": height, - "width": width, - "precision": precision, - } - - configs.append(config) - config_id += 1 - - return configs - - -def generate_3d_configs(mode="full"): - """Generate all 3D model configuration combinations - - Args: - mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs) - """ - - models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"] - - if mode == "small": - # Minimal set for quick testing - batch_sizes = [1, 4] # Just two batch sizes - temporal_sizes = [8] # Only smallest temporal size - # Very limited spatial dimensions - input_dims = [ - (112, 112), # Standard for 3D - ] - # Use only first model for minimal testing - models_3d = models_3d[:1] # Only r3d_18 - elif mode == "half": - # Reduced set for faster testing - batch_sizes = [1, 4, 8] # Skip batch_size=2 - temporal_sizes = [8, 16] # Skip 32 (most expensive) - # Reduced spatial dimensions - input_dims = [ - (112, 112), # Small (common for video) - (224, 224), # Standard - (224, 320), # Rectangular - ] - else: # full mode - # More comprehensive but still reasonable - batch_sizes = [1, 2, 4, 8] # 3D models are more memory intensive - temporal_sizes = [8, 16, 32] - # More dimensions - input_dims = [ - (112, 112), - (224, 224), - (256, 256), # Standard sizes - (224, 320), - (320, 224), # Rectangular - ] - - precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations - channels = [3] - - configs = [] - - for model in models_3d: - for batch_size in batch_sizes: - for temporal_size in temporal_sizes: - for height, width in input_dims: - for precision in precisions: - # Skip very large combinations - if batch_size > 4 and temporal_size > 16: - continue - if batch_size > 2 and height > 224: - continue - - config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}" - - config = { - "config_name": config_name, - "model": model, - "batch_size": batch_size, - "channels": channels[0], - "temporal_size": temporal_size, - "height": height, - "width": width, - "precision": precision, - } - - configs.append(config) - - return configs - - -def save_configs_to_csv(configs, filename, config_type): - """Save configurations to CSV file""" - - if not configs: - print(f"No {config_type} configurations generated") - return - - fieldnames = list(configs[0].keys()) - - with open(filename, "w", newline="\n", encoding="utf-8") as csvfile: - csvfile.write(f"# {config_type} Model Configurations\n") - csvfile.write(f"# Generated {len(configs)} configurations\n") - - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") - writer.writeheader() - - for config in configs: - writer.writerow(config) - - print(f"Generated {len(configs)} {config_type} configurations → {filename}") - - -def main(): - parser = argparse.ArgumentParser( - description="Generate model configuration combinations" - ) - parser.add_argument( - "--output-2d", - type=str, - default="model_configs_2d.csv", - help="Output file for 2D configurations", - ) - parser.add_argument( - "--output-3d", - type=str, - default="model_configs_3d.csv", - help="Output file for 3D configurations", - ) - parser.add_argument( - "--mode", - choices=["small", "half", "full"], - default="full", - help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)", - ) - parser.add_argument( - "--limit", - type=int, - help="Limit number of configurations per type (for testing)", - ) - - args = parser.parse_args() - - print(f"Generating {args.mode} model configurations...") - - print("Generating 2D model configurations...") - configs_2d = generate_2d_configs(mode=args.mode) - if args.limit: - configs_2d = configs_2d[: args.limit] - save_configs_to_csv(configs_2d, args.output_2d, "2D") - - print("Generating 3D model configurations...") - configs_3d = generate_3d_configs(mode=args.mode) - if args.limit: - configs_3d = configs_3d[: args.limit] - save_configs_to_csv(configs_3d, args.output_3d, "3D") - - print( - f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}" - ) - print("\nTo use these configurations:") - print(" Update generate_test_dataset.sh to read from these CSV files") - - -if __name__ == "__main__": - main() +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate Model Configuration Combinations for MIOpen Testing + +This script generates all possible combinations of model parameters +and saves them as CSV files that can be read by the shell script. +""" + +import csv +import argparse + + +def generate_2d_configs(mode="full"): + """Generate all 2D model configuration combinations + + Args: + mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs) + """ + + # Define parameter ranges + models_2d = [ + "resnet18", + "resnet34", + "resnet50", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "vgg11", + "vgg16", + "vgg19", + "alexnet", + "googlenet", + "densenet121", + "densenet161", + "squeezenet1_0", + "squeezenet1_1", + "shufflenet_v2_x1_0", + ] + + if mode == "small": + # Minimal set for quick testing + batch_sizes = [1, 8] # Just two batch sizes + # Very limited input dimensions - only 2 key sizes + input_dims = [ + (224, 224), # Standard (most common) + (256, 256), # Medium + ] + # Use only first 3 models for minimal testing + models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50 + elif mode == "half": + # Reduced set for faster testing + batch_sizes = [1, 8, 32] # Small, medium, large + # Reduced input dimensions - 5 key sizes + input_dims = [ + (64, 64), # Small + (224, 224), # Standard (most common) + (512, 512), # Large + (224, 320), # Rectangular + (227, 227), # AlexNet preferred + ] + else: # full mode + # More comprehensive but still limited + batch_sizes = [1, 4, 8, 16, 32] + # More dimensions but skip some redundant ones + input_dims = [ + (64, 64), + (128, 128), + (224, 224), + (256, 256), + (512, 512), # Square + (224, 320), + (320, 224), # Rectangular (reduced from 4) + (227, 227), # AlexNet preferred + (299, 299), # Inception preferred + ] + + precisions = ["fp32"] # , 'fp16', 'bf16'] + channels = [3] # Most models expect RGB + + configs = [] + config_id = 1 + + # Generate all combinations (but limit to reasonable subset) + for model in models_2d: + for batch_size in batch_sizes: + for height, width in input_dims: + for precision in precisions: + # Skip some combinations to keep dataset manageable + if batch_size > 16 and height > 256: + continue # Skip large batch + large image combinations + if precision != "fp32" and batch_size < 8: + continue # Skip mixed precision with tiny batches + + config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}" + + config = { + "config_name": config_name, + "model": model, + "batch_size": batch_size, + "channels": channels[0], + "height": height, + "width": width, + "precision": precision, + } + + configs.append(config) + config_id += 1 + + return configs + + +def generate_3d_configs(mode="full"): + """Generate all 3D model configuration combinations + + Args: + mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs) + """ + + models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"] + + if mode == "small": + # Minimal set for quick testing + batch_sizes = [1, 4] # Just two batch sizes + temporal_sizes = [8] # Only smallest temporal size + # Very limited spatial dimensions + input_dims = [ + (112, 112), # Standard for 3D + ] + # Use only first model for minimal testing + models_3d = models_3d[:1] # Only r3d_18 + elif mode == "half": + # Reduced set for faster testing + batch_sizes = [1, 4, 8] # Skip batch_size=2 + temporal_sizes = [8, 16] # Skip 32 (most expensive) + # Reduced spatial dimensions + input_dims = [ + (112, 112), # Small (common for video) + (224, 224), # Standard + (224, 320), # Rectangular + ] + else: # full mode + # More comprehensive but still reasonable + batch_sizes = [1, 2, 4, 8] # 3D models are more memory intensive + temporal_sizes = [8, 16, 32] + # More dimensions + input_dims = [ + (112, 112), + (224, 224), + (256, 256), # Standard sizes + (224, 320), + (320, 224), # Rectangular + ] + + precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations + channels = [3] + + configs = [] + + for model in models_3d: + for batch_size in batch_sizes: + for temporal_size in temporal_sizes: + for height, width in input_dims: + for precision in precisions: + # Skip very large combinations + if batch_size > 4 and temporal_size > 16: + continue + if batch_size > 2 and height > 224: + continue + + config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}" + + config = { + "config_name": config_name, + "model": model, + "batch_size": batch_size, + "channels": channels[0], + "temporal_size": temporal_size, + "height": height, + "width": width, + "precision": precision, + } + + configs.append(config) + + return configs + + +def save_configs_to_csv(configs, filename, config_type): + """Save configurations to CSV file""" + + if not configs: + print(f"No {config_type} configurations generated") + return + + fieldnames = list(configs[0].keys()) + + with open(filename, "w", newline="\n", encoding="utf-8") as csvfile: + csvfile.write(f"# {config_type} Model Configurations\n") + csvfile.write(f"# Generated {len(configs)} configurations\n") + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") + writer.writeheader() + + for config in configs: + writer.writerow(config) + + print(f"Generated {len(configs)} {config_type} configurations → {filename}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate model configuration combinations" + ) + parser.add_argument( + "--output-2d", + type=str, + default="model_configs_2d.csv", + help="Output file for 2D configurations", + ) + parser.add_argument( + "--output-3d", + type=str, + default="model_configs_3d.csv", + help="Output file for 3D configurations", + ) + parser.add_argument( + "--mode", + choices=["small", "half", "full"], + default="full", + help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)", + ) + parser.add_argument( + "--limit", + type=int, + help="Limit number of configurations per type (for testing)", + ) + + args = parser.parse_args() + + print(f"Generating {args.mode} model configurations...") + + print("Generating 2D model configurations...") + configs_2d = generate_2d_configs(mode=args.mode) + if args.limit: + configs_2d = configs_2d[: args.limit] + save_configs_to_csv(configs_2d, args.output_2d, "2D") + + print("Generating 3D model configurations...") + configs_3d = generate_3d_configs(mode=args.mode) + if args.limit: + configs_3d = configs_3d[: args.limit] + save_configs_to_csv(configs_3d, args.output_3d, "3D") + + print( + f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}" + ) + print("\nTo use these configurations:") + print(" Update generate_test_dataset.sh to read from these CSV files") + + +if __name__ == "__main__": + main() diff --git a/test_data/generate_test_dataset.sh b/test_data/generate_test_dataset.sh index 1124311feb..e9c4937445 100755 --- a/test_data/generate_test_dataset.sh +++ b/test_data/generate_test_dataset.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT # Generate Comprehensive Convolution Test Dataset for CK diff --git a/test_data/miopen_to_csv.py b/test_data/miopen_to_csv.py index d6a85e1e3f..e4ca42adeb 100644 --- a/test_data/miopen_to_csv.py +++ b/test_data/miopen_to_csv.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ diff --git a/test_data/run_model_with_miopen.py b/test_data/run_model_with_miopen.py index 9eee3b53fb..2e655fb82c 100644 --- a/test_data/run_model_with_miopen.py +++ b/test_data/run_model_with_miopen.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT """ From 9af30f04b65b8e50877d01ce8377a8cd581d462c Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Thu, 13 Nov 2025 00:56:18 -0600 Subject: [PATCH 029/118] Ck tile engine commons (#3166) * Moving Preshuffle to commons * Fixing Common Validations * Addressing Review Comments * Partial Rebasing * Partial Rebasing * Partial Rebasing * Rebasing Complete --- ...tion_utils.py => gemm_validation_utils.py} | 450 ++++++++++++++-- tile_engine/ops/gemm/codegen_utils.py | 210 -------- tile_engine/ops/gemm/gemm_instance_builder.py | 3 +- .../gemm_multi_d_instance_builder.py | 5 +- .../commons/validation_utils.py | 483 ------------------ .../gemm_preshuffle_instance_builder.py | 36 +- 6 files changed, 434 insertions(+), 753 deletions(-) rename tile_engine/ops/commons/{validation_utils.py => gemm_validation_utils.py} (58%) delete mode 100644 tile_engine/ops/gemm/codegen_utils.py delete mode 100644 tile_engine/ops/gemm_preshuffle/commons/validation_utils.py diff --git a/tile_engine/ops/commons/validation_utils.py b/tile_engine/ops/commons/gemm_validation_utils.py similarity index 58% rename from tile_engine/ops/commons/validation_utils.py rename to tile_engine/ops/commons/gemm_validation_utils.py index 5787446e8c..1b4a7191cd 100644 --- a/tile_engine/ops/commons/validation_utils.py +++ b/tile_engine/ops/commons/gemm_validation_utils.py @@ -1,16 +1,19 @@ #!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. import logging from typing import Tuple, List -# Element size mapping for different data types +GEMM_PIPELINES = ["mem", "compv3", "compv4"] + +GEMM_PRESHUFFLE_PIPELINES = ["preshufflev2"] + +LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", +} + ELEMENT_SIZE_MAP = { "fp16": 2, "bf16": 2, @@ -47,9 +50,79 @@ WARP_SUPPORTED_COMBINATIONS = { ], } -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { +GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { "gfx90a": { "fp16_fp16_fp16": [ [32, 32, 8], @@ -132,7 +205,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { }, } -# Unsupported trait combinations TRAIT_UNSUPPORTED_COMBINATIONS = { ("compv3", "cshuffle", "interwave"), ("compv3", "default", "interwave"), @@ -220,7 +292,7 @@ def validate_lds_capacity( matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 if total_tile_in_lds > max_tile_size: error_msg = ( @@ -234,7 +306,7 @@ def validate_lds_capacity( return True, "" -def validate_warp_tile_combination( +def validate_gemm_warp_tile_combination( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, @@ -250,7 +322,51 @@ def validate_warp_tile_combination( current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + gpu_warp_tile_combinations = GEMM_WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str, +) -> Tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = GEMM_PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + gpu_name, {} + ) if not gpu_warp_tile_combinations: # If GPU not recognized, try to be permissive but log warning logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") @@ -292,7 +408,6 @@ def is_tile_config_valid( pipeline: str, layout: str, gpu_target: str, - trait_name: str = None, ) -> bool: """ Comprehensive tile configuration validation. @@ -349,37 +464,81 @@ def is_tile_config_valid( logging.debug(f"LDS validation failed: {lds_error}") return False - # Validate whole workgroup cover configuration - wr_cover_valid, wg_cover_error = validate_whole_wg_cover_configuration( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - layout, - a_datatype, - b_datatype, - ) - if not wr_cover_valid: - logging.debug( - f"Whole workgroup cover configuration validation failed: {wg_cover_error}" + if pipeline in GEMM_PIPELINES: + gemm_valid, gemm_valid_error = validate_gemm( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, ) - return False + if not gemm_valid: + logging.debug(f"GEMM validation failed: {gemm_valid_error}") + return False - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_gemm_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + elif pipeline in GEMM_PRESHUFFLE_PIPELINES: + preshuffle_valid, preshuffle_valid_error = validate_gemm_preshuffle( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + layout, + gpu_target, + ) + if not preshuffle_valid: + logging.debug( + f"GEMM Preshuffle validation failed: {preshuffle_valid_error}" + ) + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = ( + validate_gemm_preshuffle_warp_tile_combination( + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + gpu_target, + ) + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False return True @@ -398,12 +557,6 @@ def get_dtype_string(datatype: str) -> str: return dtype_map.get(datatype, "float") -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. @@ -600,3 +753,200 @@ def get_global_vector_load_size( return int(PackedSize * 2 / element_size(DataType)) else: return PackedSize + + +def validate_gemm( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # GEMM Validation + # Validate whole workgroup cover configuration + whole_workgroup_cover_valid, whole_workgroup_cover_error = ( + validate_whole_wg_cover_configuration( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + layout, + a_datatype, + b_datatype, + ) + ) + if not whole_workgroup_cover_valid: + logging.debug( + f"Whole workgroup cover configuration validation failed: {whole_workgroup_cover_error}" + ) + return False, whole_workgroup_cover_error + + return True, "" + + +def validate_gemm_preshuffle( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + layout: str, + gpu_target: str, + trait_name: str = None, +) -> bool: + # Preshuffle Validations + # Validate vector load alignment + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + vector_valid, vector_error = validate_vector_load_alignment( + warp_tile_m, + warp_tile_k, + a_datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ) + if not vector_valid: + logging.debug(f"Vector load alignment failed: {vector_error}") + return False, "vector load alignment error" + + # Validate M0, M1, M2 configuration for matrix A row-major layout + m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + a_datatype, + vector_load_size=16, + warp_size=64, + ) + if not m0_m1_m2_valid: + logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") + return False, m0_m1_m2_error + + return True, "" + + +def validate_vector_load_alignment( + wg_m: int, + wg_k: int, + a_datatype: str, + m_iter_per_warp: int, + wave_size: int, + vector_load_size: int, +) -> Tuple[bool, str]: + try: + # Calculate the memory access pattern size + a_element_size = element_size(a_datatype) + access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size + + # Check if it's aligned to vector load size + if access_size % vector_load_size != 0: + error_msg = ( + f"Vector load alignment violation: " + f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " + f"% {vector_load_size} = {access_size % vector_load_size} != 0. " + f"Access size: {access_size} bytes" + ) + return False, error_msg + + return True, "" + + except Exception as e: + return False, f"Error in vector load validation: {str(e)}" + + +def validate_m0_m1_m2_configuration( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + a_datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> Tuple[bool, str]: + """ + Validate M0, M1, M2 configuration for matrix A row-major layout. + This ensures proper memory access pattern alignment. + """ + try: + # Validation for A as row-major + MPerBlock = tile_m + + # Calculate K1 using element size + K1 = vector_load_size / element_size(a_datatype) + + # Check if K1 is valid (must be integer) + if K1 != int(K1): + return ( + False, + f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", + ) + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False, f"tile_k({tile_k}) must be divisible by K1({K1})" + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False, f"warp_size({warp_size}) must be divisible by K0({K0})" + M2 = warp_size // K0 + + # Calculate number of warps and block size + NumWarps = warp_m * warp_n * warp_k + BlockSize = NumWarps * warp_size + + # Calculate M0 (assuming get_warp_size() returns warp_size) + M0 = BlockSize // warp_size # This should equal NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False, f"M2({M2}) * M0({M0}) cannot be zero" + + if MPerBlock % (M2 * M0) != 0: + return ( + False, + f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", + ) + M1 = MPerBlock // (M2 * M0) + + # Validate the assertion: M0 * M1 * M2 == MPerBlock + calculated_m_per_block = M0 * M1 * M2 + if calculated_m_per_block != MPerBlock: + error_msg = ( + f"Incorrect M0, M1, M2 configuration! " + f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " + f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" + ) + return False, error_msg + + return True, "" + + except ZeroDivisionError as e: + return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" + except Exception as e: + return False, f"Error in M0/M1/M2 validation: {str(e)}" diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py deleted file mode 100644 index eecc2228a6..0000000000 --- a/tile_engine/ops/gemm/codegen_utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# -*- coding: utf-8 -*- - -""" -Mappings and utility functions for kernel code generation. -""" - -DATA_TYPE_MAP = { - "fp32": "float", - "fp16": "ck_tile::half_t", - "bf16": "ck_tile::bf16_t", - "int8": "ck_tile::int8_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int4": "ck_tile::pk_int4_t", - "int32": "ck_tile::int32_t", -} - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - -DEFAULT_EPILOGUE = """ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - kPadM, - kPadN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - true, - memory_operation>>; -""" - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - WarpM, - WarpN, - WarpTileM, - WarpTileN, - WarpTileK, - UniversalGemmProblem::TransposeC, - memory_operation>>; -""" - -PIPELINE_MAP = { - "mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"], - "compv3": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV3", - "ck_tile::GemmPipelineAgBgCrCompV3", - ], - "compv4": [ - "ck_tile::BaseGemmPipelineAgBgCrCompV4", - "ck_tile::GemmPipelineAgBgCrCompV4", - ], -} - -SCHEDULER_MAP = { - "interwave": "ck_tile::GemmPipelineScheduler::Interwave", - "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", -} - -EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE} - - -def BOOL_MAP(b_): - return {True: "true", False: "false"}[bool(b_)] - - -# To Do: add some more supported combinations -warp_tile_supported_combinations = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [4, 64, 16], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - "fp8_bf8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - "bf8_fp8_fp16": [ - [16, 16, 128], - [32, 32, 64], - ], - }, - "gfx1201": { - "fp16_fp16_fp16": [ - [16, 16, 16], - ], - }, -} - -# To Do: remove some unsupported combinations -trait_unsupported_combinations = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 8885c821c1..d450f20105 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index cc167fb75f..06da7ea8a2 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -21,7 +21,8 @@ def _import_validation_utils(): # Load the module dynamically spec = importlib.util.spec_from_file_location( - "validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py") + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), ) validation_utils = importlib.util.module_from_spec(spec) spec.loader.exec_module(validation_utils) @@ -824,7 +825,7 @@ def main(): elif elementwise_function == "add": function_name = "MultiDAdd" elif elementwise_function == "passthrough": - function_name = "PassThrough" # TODO Change this + function_name = "PassThrough" args.elementwise_function = function_name diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py deleted file mode 100644 index 70ce3b0d72..0000000000 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ /dev/null @@ -1,483 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -""" -Validation utilities for GEMM kernel generation. -Extracted from tile_engine_develop for consistency. -""" - -import logging -from typing import Tuple, List - -# Element size mapping for different data types -ELEMENT_SIZE_MAP = { - "fp16": 2, - "bf16": 2, - "int8": 1, - "fp8": 1, - "bf8": 1, - "int4": 0.5, - "int32": 4, - "fp32": 4, - "fp64": 8, -} - -# [TODO] Handle this while moving code to commons -# Supported warp tile combinations for different GPU architectures and data types -WARP_TILE_SUPPORTED_COMBINATIONS = { - "gfx90a": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], - }, - "gfx942": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], - "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], - }, - "gfx950": { - "fp16_fp16_fp16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "bf16_bf16_bf16": [ - [32, 32, 8], - [16, 16, 16], - [32, 32, 16], - [16, 16, 32], - [64, 4, 16], - ], - "fp8_fp8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 32], - [16, 16, 64], - [16, 16, 128], - [32, 32, 64], - ], - "bf8_bf8_fp16": [ - [32, 32, 16], - [32, 32, 32], - [16, 16, 64], - [16, 16, 32], - [16, 16, 128], - [32, 32, 64], - ], - }, -} - -# Unsupported trait combinations -TRAIT_UNSUPPORTED_COMBINATIONS = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), -} - - -def element_size(data_type: str) -> float: - """Calculate the size (in bytes) of a single element for given data type.""" - data_type = data_type.lower() - if data_type not in ELEMENT_SIZE_MAP: - raise ValueError(f"Unsupported data type: {data_type}") - return ELEMENT_SIZE_MAP[data_type] - - -def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: - """Check if a trait combination is valid.""" - if pipeline not in ["preshufflev2"]: - raise ValueError("Accepted pipeline values are: ['preshufflev2']") - if epilogue not in ["default", "cshuffle"]: - return ValueError("Accepted epilogue values are: ['default', 'cshuffle']") - if scheduler not in ["default"]: - return ValueError("Accepted scheduler values are: ['default']") - return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS - - -def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: - """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] - - -def validate_dimension_alignment( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, -) -> Tuple[bool, List[str]]: - """Check if tile dimensions are properly aligned with warp dimensions.""" - alignment_issues = [] - - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - return len(alignment_issues) == 0, alignment_issues - - -def validate_lds_capacity( - tile_m: int, - tile_n: int, - tile_k: int, - a_datatype: str, - b_datatype: str, - pipeline: str, -) -> Tuple[bool, str]: - """Validate LDS capacity requirements.""" - matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) - matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16 - - if total_tile_in_lds > max_tile_size: - error_msg = ( - f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" - f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False, error_msg - - return True, "" - - -def validate_warp_tile_combination( - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - gpu_name: str, -) -> Tuple[bool, str]: - """Validate warp tile combination against GPU-specific supported combinations.""" - - # Construct the key for looking up supported combinations - warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - # Check if we have GPU-specific combinations - gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) - if not gpu_warp_tile_combinations: - # If GPU not recognized, try to be permissive but log warning - logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") - return True, "" - - # Check if we have combinations for this data type combination - allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) - if not allowed_combinations: - # For data type combinations not in the list, be permissive - logging.debug( - f"No warp tile combinations found for data types: {warp_tile_key}" - ) - return True, "" - - # Check if current combination is in the allowed list - if current_combination not in allowed_combinations: - error_msg = ( - f"Invalid warp tile combination: {current_combination} not in allowed list. " - f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" - ) - return False, error_msg - - return True, "" - - -def is_tile_config_valid( - tile_m: int, - tile_n: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - warp_tile_m: int, - warp_tile_n: int, - warp_tile_k: int, - a_datatype: str, - b_datatype: str, - c_datatype: str, - pipeline: str, - gpu_target: str, - trait_name: str = None, -) -> bool: - """ - Comprehensive tile configuration validation. - Returns True if configuration is valid, False otherwise. - """ - # Basic sanity checks - if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: - return False - if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: - return False - if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: - return False - - # Check that warp tiles fit within block tiles - if warp_m * warp_tile_m > tile_m: - return False - if warp_n * warp_tile_n > tile_n: - return False - if warp_k * warp_tile_k > tile_k: - return False - - # Validate vector load alignment - m_iter_per_warp = tile_m / (warp_m * warp_tile_m) - vector_valid, vector_error = validate_vector_load_alignment( - warp_tile_m, - warp_tile_k, - a_datatype, - m_iter_per_warp, - wave_size=64, - vector_load_size=16, - ) - if not vector_valid: - logging.debug(f"Vector load alignment failed: {vector_error}") - return False - - # Validate M0, M1, M2 configuration for matrix A row-major layout - m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration( - tile_m, - tile_k, - warp_m, - warp_n, - warp_k, - a_datatype, - vector_load_size=16, - warp_size=64, - ) - if not m0_m1_m2_valid: - logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}") - return False - - # Validate warp configuration - if not validate_warp_configuration(warp_m, warp_n, warp_k): - logging.debug( - f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" - ) - return False - - # Validate dimension alignment - is_aligned, alignment_issues = validate_dimension_alignment( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) - if not is_aligned: - logging.debug( - f"Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # Validate LDS capacity - lds_valid, lds_error = validate_lds_capacity( - tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline - ) - if not lds_valid: - logging.debug(f"LDS validation failed: {lds_error}") - return False - - # Validate warp tile combination - warp_tile_valid, warp_tile_error = validate_warp_tile_combination( - warp_tile_m, - warp_tile_n, - warp_tile_k, - a_datatype, - b_datatype, - c_datatype, - gpu_target, - ) - if not warp_tile_valid: - logging.debug(f"Warp tile validation failed: {warp_tile_error}") - return False - - return True - - -def validate_vector_load_alignment( - wg_m: int, - wg_k: int, - a_datatype: str, - m_iter_per_warp: int, - wave_size: int, - vector_load_size: int, -) -> Tuple[bool, str]: - try: - # Calculate the memory access pattern size - a_element_size = element_size(a_datatype) - access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size - - # Check if it's aligned to vector load size - if access_size % vector_load_size != 0: - error_msg = ( - f"Vector load alignment violation: " - f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) " - f"% {vector_load_size} = {access_size % vector_load_size} != 0. " - f"Access size: {access_size} bytes" - ) - return False, error_msg - - return True, "" - - except Exception as e: - return False, f"Error in vector load validation: {str(e)}" - - -def validate_m0_m1_m2_configuration( - tile_m: int, - tile_k: int, - warp_m: int, - warp_n: int, - warp_k: int, - a_datatype: str, - vector_load_size: int = 16, - warp_size: int = 64, -) -> Tuple[bool, str]: - """ - Validate M0, M1, M2 configuration for matrix A row-major layout. - This ensures proper memory access pattern alignment. - """ - try: - # Validation for A as row-major - MPerBlock = tile_m - - # Calculate K1 using element size - K1 = vector_load_size / element_size(a_datatype) - - # Check if K1 is valid (must be integer) - if K1 != int(K1): - return ( - False, - f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})", - ) - K1 = int(K1) - - # Calculate K0 - if tile_k % K1 != 0: - return False, f"tile_k({tile_k}) must be divisible by K1({K1})" - K0 = tile_k // K1 - - # Calculate M2 - if warp_size % K0 != 0: - return False, f"warp_size({warp_size}) must be divisible by K0({K0})" - M2 = warp_size // K0 - - # Calculate number of warps and block size - NumWarps = warp_m * warp_n * warp_k - BlockSize = NumWarps * warp_size - - # Calculate M0 (assuming get_warp_size() returns warp_size) - M0 = BlockSize // warp_size # This should equal NumWarps - - # Calculate M1 - if (M2 * M0) == 0: - return False, f"M2({M2}) * M0({M0}) cannot be zero" - - if MPerBlock % (M2 * M0) != 0: - return ( - False, - f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}", - ) - M1 = MPerBlock // (M2 * M0) - - # Validate the assertion: M0 * M1 * M2 == MPerBlock - calculated_m_per_block = M0 * M1 * M2 - if calculated_m_per_block != MPerBlock: - error_msg = ( - f"Incorrect M0, M1, M2 configuration! " - f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). " - f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}" - ) - return False, error_msg - - return True, "" - - except ZeroDivisionError as e: - return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}" - except Exception as e: - return False, f"Error in M0/M1/M2 validation: {str(e)}" - - -# [TODO] Handle this while moving code to commons Add more datatype to this function if needed -def get_dtype_string(datatype: str) -> str: - """Get C++ type string for datatype""" - dtype_map = { - "fp16": "ck_tile::fp16_t", - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", - "fp64": "double", - } - return dtype_map.get(datatype, "float") - - -LAYOUT_MAP = { - "r": "ck_tile::tensor_layout::gemm::RowMajor", - "c": "ck_tile::tensor_layout::gemm::ColumnMajor", -} - - -def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]: - """ - Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. - """ - code = str(layout_code).strip().lower() - - a_layout = LAYOUT_MAP[code[0]] - b_layout = LAYOUT_MAP[code[1]] - c_layout = LAYOUT_MAP[code[2]] - return a_layout, b_layout, c_layout diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 9ce6d8cb25..654a039b9c 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -8,15 +8,34 @@ import itertools import logging import multiprocessing import concurrent.futures - from pathlib import Path +import importlib.util -from commons.validation_utils import ( - is_tile_config_valid, - is_trait_combination_valid, - get_dtype_string, - get_abc_layouts, -) + +def _import_validation_utils(): + """Import validation utilities from commons directory.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + parent_dir = os.path.dirname(current_dir) + + # Load the module dynamically + spec = importlib.util.spec_from_file_location( + "validation_utils", + os.path.join(parent_dir, "commons", "gemm_validation_utils.py"), + ) + validation_utils = importlib.util.module_from_spec(spec) + spec.loader.exec_module(validation_utils) + + return validation_utils + + +# Import validation functions +_validation_utils = _import_validation_utils() +is_tile_config_valid = _validation_utils.is_tile_config_valid +is_trait_combination_valid = _validation_utils.is_trait_combination_valid +get_dtype_string = _validation_utils.get_dtype_string +get_abc_layouts = _validation_utils.get_abc_layouts + +logging.basicConfig(level=logging.INFO) class GemmPreshuffleKernelBuilder: @@ -305,6 +324,8 @@ class GemmPreshuffleKernelBuilder: b_datatype = self.datatype c_datatype = self.datatype + layout = self.layout + # Special handling for certain data types if self.datatype in ["fp8", "bf8"]: c_datatype = "fp16" @@ -324,6 +345,7 @@ class GemmPreshuffleKernelBuilder: b_datatype, c_datatype, pipeline, + layout, self.gpu_target, ) From 6fd8ddabe798b1856a92049c5979611246b5b367 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 13 Nov 2025 00:43:40 -0700 Subject: [PATCH 030/118] [CK TILE GEMM] Refactor block_scale_gemm examples (#3181) * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * [CK TILE GEMM] Refactor block_scale_gemm examples - Set quant group size to (1, 1, 64) for targets excluding gfx950, where warp tile size (16, 16, 128) is incompatible. --- .../38_block_scale_gemm/CMakeLists.txt | 15 +- example/ck_tile/38_block_scale_gemm/README.md | 42 +- .../gemm_aquant_quantgrouped.cpp | 53 +++ .../gemm_bquant_quantgrouped_prefill_bf8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_bf8i4.cpp | 49 ++ .../gemm_bquant_quantgrouped_prefill_fp8.cpp | 47 ++ ...gemm_bquant_quantgrouped_prefill_fp8i4.cpp | 49 ++ ...quant_quantgrouped_preshuffleb_prefill.cpp | 53 +++ .../38_block_scale_gemm/gemm_quant.cpp | 130 ++++++ .../38_block_scale_gemm/gemm_quant_basic.cpp | 428 ------------------ .../38_block_scale_gemm/gemm_quant_rowcol.cpp | 30 ++ .../38_block_scale_gemm/gemm_quant_tensor.cpp | 30 ++ .../38_block_scale_gemm/gemm_utils.hpp | 54 +-- .../run_gemm_quant_example.inc | 273 ++++++++++- 14 files changed, 805 insertions(+), 495 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant.cpp delete mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index b1ae9369a2..932acb72fd 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,8 +6,19 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) - target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + set(EXE_NAME tile_example_gemm_quant) + add_executable(${EXE_NAME} EXCLUDE_FROM_ALL + gemm_quant.cpp + gemm_aquant_quantgrouped.cpp + gemm_bquant_quantgrouped_prefill_bf8i4.cpp + gemm_bquant_quantgrouped_prefill_fp8i4.cpp + gemm_bquant_quantgrouped_prefill_bf8.cpp + gemm_bquant_quantgrouped_prefill_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_prefill.cpp + gemm_quant_rowcol.cpp + gemm_quant_tensor.cpp + ) + target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 496697ca32..64ecebd15a 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic` ## example ``` args: - -b batch size (default:1) - -m m dimension (default:1024) - -n n dimension (default:2048) - -k k dimension (default:64) - -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: C) - -c_layout Tensor C data layout (default: R) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) - -e Absolute error tolerance (default:1e-5) - -prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8) - -warmup number of iterations before benchmark the kernel (default:10) - -repeat number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -quant_mode Which quant method to use (aquant, bquant, tensor, rowcol) + -h Print help message (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -a_layout A tensor data layout - Row or Column (default:R) + -b_layout B tensor data layout - Row or Column (default:C) + -bq_layout Bq tensor data layout - Row or Column (default:C) + -c_layout C tensor data layout - Row or Column (default:R) + -stride_a Tensor A stride (default:0) + -stride_q Tensor AQ stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -warmup Number of iterations before benchmarking the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:1000) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k SplitK value (default:1) + -device Device id that will be used to run the kernel (default:0) + -init 0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache Flush cache before running the kernel (default:true) +-rotating_count Rotating count (default:1000) + -quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant) + -preshuffleb Enable preshuffle of tensor B (default:false) + -group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128) ``` User need to select correct mapping of config for each quant mode: diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp new file mode 100644 index 0000000000..3786230ff0 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuant; + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp new file mode 100644 index 0000000000..cb9f8b62cf --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp new file mode 100644 index 0000000000..33ae3bc4a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp new file mode 100644 index 0000000000..526c35b081 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp new file mode 100644 index 0000000000..4b2a8efb14 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp new file mode 100644 index 0000000000..d9591bb588 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp new file mode 100644 index 0000000000..a35f867f5d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "gemm_utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("h", "false", "Print help message") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row or Column") + .insert("b_layout", "C", "B tensor data layout - Row or Column") + .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") + .insert("c_layout", "R", "C tensor data layout - Row or Column") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU") + .insert("prec", + "fp8", + "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " + "or bf8i4") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "SplitK value") + .insert("device", "0", "Device id that will be used to run the kernel") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "Flush cache before running the kernel") + .insert("rotating_count", "1000", "Rotating count") + .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") + .insert("preshuffleb", "false", "Enable preshuffle of tensor B") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +auto gen_lut_key(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + + std::vector params = {data_type, quant_mode}; + + if(quant_mode == "bquant") + { + std::string preshuffleb = + arg_parser.get_bool("preshuffleb") ? "preshuffleb" : "non-preshuffleb"; + params.push_back(preshuffleb); + } + if(quant_mode != "rowcol" && quant_mode != "tensor") + { + // NOTE: rowcol and tensor pipeline do not use group size + std::string group_size_str = arg_parser.get_str("group_size"); + params.push_back(group_size_str); + } + + return hash_multiple_strings(params); +} + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut); +void quant_rowcol_instance_factory( + std::unordered_map>& lut); +void quant_tensor_instance_factory( + std::unordered_map>& lut); + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result || arg_parser.get_bool("h")) + { + arg_parser.print(); + return -1; + } + + auto device_id = arg_parser.get_int("device"); + std::cout << "Device ID: " << device_id << std::endl; + ck_tile::hip_check_error(hipSetDevice(device_id)); + + std::unordered_map> lut; + aquant_quantgrouped_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_instance_factory(lut); + quant_rowcol_instance_factory(lut); + quant_tensor_instance_factory(lut); + + auto key = gen_lut_key(arg_parser); + + if(lut.find(key) != lut.end()) + { + return lut[key](arg_parser); + } + else + { + std::cerr + << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." + << std::endl; + return -1; + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp deleted file mode 100644 index d605a2b780..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ /dev/null @@ -1,428 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -// This example demonstrates 2D block scale quantization (N×K) for BQuant -// using non-preshuffled configuration. -// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "gemm_utils.hpp" - -template -float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - static_assert(std::is_same_v); - using ComputeDataType = std::conditional_t; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using GemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 - using BaseGemmPipeline = std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; - - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = false; - - // row-col and tensor quants use the regular pipeline, A/B quants use their own - using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>>; - - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded - // for aquant - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - GemmConfig::TiledMMAPermuteN>>; - using Kernel = - ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << PipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - float ave_time = 0; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper - rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString( - hipMemsetAsync(args.c_ptr, - 0, - args.M * args.N * sizeof(typename TypeConfig::CDataType), - s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_quant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if((QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && - GemmConfig::PreshuffleB) - { - throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -// Forward declaration for dispatch function -template