From 069500464de6a55b80e8341c79239b13ac8ef379 Mon Sep 17 00:00:00 2001 From: Jan Patrick Lehr Date: Mon, 2 Feb 2026 18:39:48 +0100 Subject: [PATCH 1/5] [Compiler] Addressing new compiler warnings (#3640) * [Compiler] Addressing new compiler warnings Clang enables new lifetime warnings in production and we see build errors due to this with the staging compiler. The attributes added in this PR are suggested by the compiler. However, I'm not very familiar with the code base, so the changes may be incorrect. * Update some more instances * Adds file-level ignores via clang diagnostic pragma The number of instances was large, so I decided to use file-level scope to disable the warning via pragma clang diagnostic ignored. It also showed this warning coming from the gtest dependency. For that, I did add the respective command line flag to the CMake variables. I don't know if this is acceptable or not. * This adds the remaining instances For a build on gfx90a. * fix clang format * Adding couple more instances from gfx1200 build * Fixed another few instances --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng --- cmake/gtest.cmake | 2 + example/ck_tile/01_fmha/bias.hpp | 2 +- example/ck_tile/01_fmha/mask.hpp | 2 +- example/ck_tile/01_fmha/quant.hpp | 4 ++ include/ck/host_utility/io.hpp | 5 ++- .../library/utility/convolution_parameter.hpp | 3 +- include/ck/library/utility/host_tensor.hpp | 12 ++++-- include/ck/tensor/static_tensor.hpp | 3 ++ .../multi_index_transform.hpp | 39 +++++++++++++++---- .../ck/tensor_description/tensor_adaptor.hpp | 5 ++- .../tensor_description/tensor_descriptor.hpp | 17 ++++++-- .../blockwise_gemm_pipeline_wmmaops_base.hpp | 3 ++ .../block/blockwise_gemm_pipeline_xdlops.hpp | 3 ++ .../blockwise_gemm_pipeline_xdlops_base.hpp | 4 ++ .../gpu/block/blockwise_gemm_wmma.hpp | 4 ++ .../gpu/block/blockwise_gemm_xdlops.hpp | 3 ++ .../blockwise_gemm_xdlops_skip_b_lds.hpp | 4 ++ .../gpu/device/tensor_layout.hpp | 2 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 4 ++ .../ck/utility/amd_wave_read_first_lane.hpp | 3 +- include/ck/utility/dtype_vector.hpp | 21 +++++----- include/ck/utility/env.hpp | 5 ++- include/ck/utility/pipeline_enum.hpp | 3 +- include/ck/utility/scheduler_enum.hpp | 3 +- include/ck/utility/static_buffer.hpp | 5 ++- include/ck/utility/tuple.hpp | 7 ++-- include/ck/wrapper/layout.hpp | 4 ++ include/ck/wrapper/tensor.hpp | 4 ++ .../core/algorithm/coordinate_transform.hpp | 4 ++ include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 4 ++ include/ck_tile/core/container/map.hpp | 4 ++ include/ck_tile/core/container/tuple.hpp | 11 ++++-- include/ck_tile/core/numeric/e8m0.hpp | 4 ++ include/ck_tile/core/numeric/pk_fp4.hpp | 4 ++ .../core/tensor/static_distributed_tensor.hpp | 4 ++ .../ck_tile/core/tensor/tensor_adaptor.hpp | 4 ++ .../core/tensor/tensor_adaptor_coordinate.hpp | 4 ++ include/ck_tile/core/tensor/tensor_view.hpp | 4 ++ .../ck_tile/core/tensor/tile_distribution.hpp | 4 ++ include/ck_tile/core/utility/env.hpp | 4 ++ include/ck_tile/core/utility/functional.hpp | 3 ++ include/ck_tile/host/arg_parser.hpp | 4 ++ include/ck_tile/host/host_tensor.hpp | 4 ++ .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 ++- profiler/src/profiler_operation_registry.hpp | 4 ++ .../position_embedding/position_embedding.cpp | 4 ++ .../gemm_multi_d/gemm_multi_d_benchmark.hpp | 4 ++ .../gemm_preshuffle_benchmark.hpp | 4 ++ .../gemm/gemm_universal/gemm_benchmark.hpp | 3 ++ .../gemm_streamk/gemm_streamk_benchmark.hpp | 4 ++ 50 files changed, 228 insertions(+), 43 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 993330f989..51e0359ab6 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,6 +68,8 @@ set(GTEST_CXX_FLAGS -Wno-deprecated -Wno-unsafe-buffer-usage -Wno-float-equal + -Wno-lifetime-safety-intra-tu-suggestions + -Wno-lifetime-safety-cross-tu-suggestions ) if(WIN32) diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp index 33f398cc2a..b526204384 100644 --- a/example/ck_tile/01_fmha/bias.hpp +++ b/example/ck_tile/01_fmha/bias.hpp @@ -106,7 +106,7 @@ struct bias_info return info; } - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const bias_info& bi) { bi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index f85b811116..c780bf7b6b 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -191,7 +191,7 @@ struct mask_info return area; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const mask_info& mi) { mi.serialize(os); return os; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24..da588910b2 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // keep sync with BlockAttentionQuantScaleEnum enum class quant_scale_enum { @@ -58,3 +61,4 @@ struct quant_scale_info return os; } }; +#pragma clang diagnostic pop diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index db45199b17..22d744ff15 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -13,7 +13,7 @@ namespace ck { template -std::ostream& operator<<(std::ostream& os, const std::vector& v) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const std::vector& v) { std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); return os; @@ -27,7 +27,8 @@ std::ostream& operator<<(std::ostream& os, const std::array& v) } template -std::ostream& operator<<(std::ostream& os, const TensorDescriptor& desc) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const TensorDescriptor& desc) { constexpr index_t nDim = remove_cvref_t::GetNumOfDimension(); diff --git a/include/ck/library/utility/convolution_parameter.hpp b/include/ck/library/utility/convolution_parameter.hpp index 354b112040..a25002409b 100644 --- a/include/ck/library/utility/convolution_parameter.hpp +++ b/include/ck/library/utility/convolution_parameter.hpp @@ -110,4 +110,5 @@ ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) } // namespace utils } // namespace ck -std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::utils::conv::ConvParam& p); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 1dda0a4863..2e95ee8cf3 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -23,10 +23,14 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" +#pragma clang diagnostic ignored "-Wlifetime-safety-cross-tu-suggestions" + namespace ck { template -std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +std::ostream& LogRange([[clang::lifetimebound]] std::ostream& os, Range&& range, std::string delim) { bool first = true; for(auto&& v : range) @@ -580,8 +584,9 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); - friend std::ostream& operator<<(std::ostream& os, ChosenLayout tag); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const HostTensorDescriptor& desc); + friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, ChosenLayout tag); private: std::vector mLens; @@ -1171,3 +1176,4 @@ struct Tensor }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index 529745e3b9..c3f3bd0c91 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -4,6 +4,8 @@ #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // StaticTensor for Scalar @@ -270,4 +272,5 @@ __host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_elem } } // namespace ck +#pragma clang diagnostic pop #endif diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 19a4748732..5a6c335b2c 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -6,6 +6,9 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/multi_index.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template @@ -29,7 +32,10 @@ struct PassThrough __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -305,7 +311,10 @@ struct RightPad __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -403,7 +412,10 @@ struct Embed __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1074,7 +1086,10 @@ struct Merge_v2_magic_division __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1366,7 +1381,10 @@ struct Merge_v3_division_mod __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1480,7 +1498,10 @@ struct UnMerge __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, @@ -1640,7 +1661,10 @@ struct ConvBwdDataImplicitGemmOutTransform __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; } - __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + __host__ __device__ constexpr const auto& GetUpperLengths() const [[clang::lifetimebound]] + { + return up_lengths_; + } template __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const @@ -2236,3 +2260,4 @@ struct Xor } }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 79c5881d48..ee8c7ed71b 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -23,7 +23,10 @@ struct TensorAdaptor { __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() { diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d11..a237c4219d 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -7,6 +7,8 @@ #include "ck/utility/sequence_helper.hpp" #include "ck/tensor_description/multi_index_transform.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -179,7 +181,10 @@ struct TensorDescriptor } // TODO make these private - __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + __host__ __device__ constexpr const auto& GetTransforms() const [[clang::lifetimebound]] + { + return transforms_; + } __host__ __device__ static constexpr auto GetLowerDimensionIdss() { @@ -253,9 +258,12 @@ struct TensorCoordinate __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } // TODO make these private - __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + __host__ __device__ constexpr const auto& GetHiddenIndex() const [[clang::lifetimebound]] + { + return idx_hidden_; + } - __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + __host__ __device__ auto& GetHiddenIndex() [[clang::lifetimebound]] { return idx_hidden_; } __host__ __device__ constexpr auto GetVisibleIndex() const { @@ -284,7 +292,7 @@ struct TensorCoordinateStep __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } // TODO make these private - __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const [[clang::lifetimebound]] { return idx_diff_visible_; } @@ -613,3 +621,4 @@ using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f831c0f6cf..e41cf8c82d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -10,6 +10,8 @@ #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { template @@ -1031,3 +1033,4 @@ struct BlockwiseGemmXdlops_v2 }; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 1dba7f67a1..65a326e3e7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -8,6 +8,9 @@ #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { template ::value, bool>::type = false> -std::ostream& operator<<(std::ostream& os, const Layout&) +std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, const Layout&) { os << Layout::name; return os; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 8c316bc71d..6060889c10 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -17,6 +17,9 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to @@ -1132,3 +1135,4 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight }; // namespace ck } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 44259f0601..4b64b76cc7 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -44,7 +44,8 @@ struct get_carrier<3> // replacement of host std::copy_n() template - __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + __device__ static OutputIterator + copy_n(InputIterator from, Size size, [[clang::lifetimebound]] OutputIterator to) { if(0 < size) { diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ebdbbb107d..204b199629 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck/utility/data_type.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // vector_type @@ -116,7 +118,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -136,7 +138,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -248,7 +250,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -272,7 +274,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); @@ -583,7 +585,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, @@ -754,7 +756,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -1427,7 +1429,7 @@ struct non_native_vector_base< } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); @@ -1627,7 +1629,7 @@ struct vector_type()>> __host__ __device__ constexpr vector_type(type v) : data_{v} {} template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value, @@ -1797,7 +1799,7 @@ struct vector_type()>> } template - __host__ __device__ constexpr auto& AsType() + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || @@ -2284,3 +2286,4 @@ using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 0cb0b4caf8..4cabd89e33 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck { namespace internal { template @@ -188,5 +191,5 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#pragma clang diagnostic pop #endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp index 4421386f59..a224011a04 100644 --- a/include/ck/utility/pipeline_enum.hpp +++ b/include/ck/utility/pipeline_enum.hpp @@ -25,7 +25,8 @@ enum struct PipelineVersion } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::PipelineVersion& p) { switch(p) { diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp index 0c4bfabaf3..67c5c3b50a 100644 --- a/include/ck/utility/scheduler_enum.hpp +++ b/include/ck/utility/scheduler_enum.hpp @@ -70,7 +70,8 @@ enum struct TailNumber } // namespace ck #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck::LoopScheduler& s) { switch(s) { diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index d49817eb8f..7e47da5bf8 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -5,6 +5,8 @@ #include "statically_indexed_array.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { // static buffer for scalar @@ -104,7 +106,7 @@ struct StaticBufferTupleOfVector // Set S // i is offset of S template - __host__ __device__ constexpr S& operator()(Number i) + __host__ __device__ constexpr S& operator()(Number i) [[clang::lifetimebound]] { constexpr auto i_v = i / s_per_v; constexpr auto i_s = i % s_per_v; @@ -195,3 +197,4 @@ __host__ __device__ constexpr auto make_static_buffer(LongNumber) } } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 1657595030..16cd35e1d6 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -51,7 +51,7 @@ get_tuple_element_data_reference(const TupleElementKeyData& x) // for write access of tuple element template __host__ __device__ constexpr Data& -get_tuple_element_data_reference(TupleElementKeyData& x) +get_tuple_element_data_reference([[clang::lifetimebound]] TupleElementKeyData& x) { return x.mData; } @@ -106,6 +106,7 @@ struct TupleImpl, Xs...> : TupleElementKeyData __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) + [[clang::lifetimebound]] { return get_tuple_element_data_reference>(*this); } @@ -147,7 +148,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& At(Number) + __host__ __device__ constexpr auto& At(Number) [[clang::lifetimebound]] { static_assert(I < base::Size(), "wrong! out of range"); return base::GetElementDataByKey(detail::TupleElementKey{}); @@ -162,7 +163,7 @@ struct Tuple : detail::TupleImpl - __host__ __device__ constexpr auto& operator()(Number i) + __host__ __device__ constexpr auto& operator()(Number i) [[clang::lifetimebound]] { return At(i); } diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 334d5851db..6d99f4e5e3 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -5,6 +5,9 @@ #include "ck/wrapper/utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -482,3 +485,4 @@ struct Layout } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 9f8278a357..ed7f2fa23d 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -7,6 +7,9 @@ #include "utils/tensor_partition.hpp" #include "utils/layout_utils.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + // Disable from doxygen docs generation /// @cond INTERNAL namespace ck { @@ -441,3 +444,4 @@ struct Tensor } // namespace wrapper } // namespace ck +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 732799cef8..30c93b8f00 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -11,6 +11,9 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { enum struct coord_transform_enum @@ -1776,3 +1779,4 @@ make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingA } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 4c9ef7d6ba..1eef5819bc 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -7,6 +7,9 @@ #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/ignore.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile::core::arch::mma { /** @@ -112,6 +115,7 @@ struct amdgcn_mma }; } // namespace ck_tile::core::arch::mma +#pragma clang diagnostic pop // Include the implementations #include "wmma/wmma.hpp" diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index d342235b38..8c861ceeb6 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -8,6 +8,9 @@ #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // naive map @@ -157,3 +160,4 @@ CK_TILE_HOST_DEVICE static void print(const map& m) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 7f8176d5ec..11e7b1e52f 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef CK_TILE_TUPLE_IMPL #define CK_TILE_TUPLE_IMPL 1 #endif @@ -98,13 +101,14 @@ CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) } template -CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) +CK_TILE_HOST_DEVICE constexpr const T& +getv([[clang::lifetimebound]] const tuple_object& x) { return x.element; } template -CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) +CK_TILE_HOST_DEVICE constexpr T& getv([[clang::lifetimebound]] tuple_object& x) { return x.element; } @@ -292,7 +296,7 @@ struct tuple : impl::tuple_base, T...> //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast&>(*this).at(i); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } - + // template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(i) = x; } template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } @@ -864,3 +868,4 @@ struct tuple_element> } \ }() #endif +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp index 41aeb8ffab..ee12524283 100644 --- a/include/ck_tile/core/numeric/e8m0.hpp +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -6,6 +6,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /** @@ -100,3 +103,4 @@ CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index d74db6b336..5822e3b9bc 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -9,6 +9,9 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/mxfp_convert.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #if defined(__gfx950__) #define CK_TILE_FP4_CVT_DEVICE 1 #else @@ -517,3 +520,4 @@ CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const #endif } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb..bdd81dae07 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/container/thread_buffer.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -266,3 +269,4 @@ inline constexpr bool is_similiar_distributed_tensor_v = } // namespace detail } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 78160b800d..e6cdb66ef9 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -12,6 +12,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { // Transforms: Tuple @@ -950,3 +953,4 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. remove_cvref_t, \ remove_cvref_t>{trans}; \ }() +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp index 2ea76a3814..6d33bde83e 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/print.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -367,3 +370,4 @@ CK_TILE_HOST_DEVICE void print(const tensor_adaptor_coordinate& coord) detail::CK_PRINT_X_<>{}(coord); } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 837f2b87a6..833a7f4413 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -14,6 +14,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* @@ -582,3 +585,4 @@ pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index f9c2aba502..aa5714e5c2 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -15,6 +15,9 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -731,3 +734,4 @@ CK_TILE_HOST_DEVICE void print(const tile_distribution #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -206,3 +209,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) // environment variable to enable logging: // export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) +#pragma clang diagnostic pop diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index aa4bfa3f15..ae79d575a8 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -10,6 +10,8 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck_tile { namespace detail { @@ -270,3 +272,4 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index 8c45d2b175..fee7f7779b 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -13,6 +13,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { /* * a host side utility, arg parser for, either @@ -234,3 +237,4 @@ class ArgParser std::vector keys; }; } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index d26686ec37..ddeb3ad781 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -17,6 +17,9 @@ #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/ranges.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + namespace ck_tile { template @@ -859,3 +862,4 @@ auto get_default_stride(std::size_t row, return stride; } } // namespace ck_tile +#pragma clang diagnostic pop diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 957cf7ab8f..987704e433 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -41,7 +41,8 @@ enum struct TailNumber } // namespace ck_tile -inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::GemmPipelineScheduler& s) { switch(s) { @@ -53,7 +54,8 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch return os; } -inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s) +inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os, + const ck_tile::TailNumber& s) { switch(s) { diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 28674554a1..fd698ee340 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -9,6 +9,9 @@ #include #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + class ProfilerOperationRegistry final { ProfilerOperationRegistry() = default; @@ -83,3 +86,4 @@ class ProfilerOperationRegistry final ::ProfilerOperationRegistry::GetInstance().Add(name, description, operation) \ _Pragma("clang diagnostic pop") // clang-format on +#pragma clang diagnostic pop diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp index 134d2e5f37..689a7a799a 100644 --- a/test/position_embedding/position_embedding.cpp +++ b/test/position_embedding/position_embedding.cpp @@ -9,6 +9,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + #ifndef TEST_ALIBI_VERBOSE #define TEST_ALIBI_VERBOSE 0 #endif @@ -213,3 +216,4 @@ int main() // clang-format on return rtn ? 0 : -1; } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp index f8c196e32a..b0d8445c16 100644 --- a/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -13,6 +13,9 @@ #include "ck_tile/host.hpp" #include "gemm_multi_d_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-seggestions" + // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -230,3 +233,4 @@ void gemm_multi_d_host_reference(int verify, a_m_k, b_k_n, {d0_m_n, d1_m_n}, c_m_n_host_result); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 748fe581d3..41ccc4a01b 100644 --- a/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -7,6 +7,9 @@ #include "ck_tile/host.hpp" #include "gemm_preshuffle_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + //[TODO] Move parts of this File to commons enum class Metric { @@ -234,3 +237,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp index 7c8df32ad8..11aef4c251 100644 --- a/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp +++ b/tile_engine/ops/gemm/gemm_universal/gemm_benchmark.hpp @@ -13,6 +13,8 @@ #include "ck_tile/host.hpp" #include "gemm_common.hpp" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts @@ -240,3 +242,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index 45beb0acce..d877f174b2 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -17,6 +17,9 @@ // Data types and Layouts are defined by the generated kernel headers // No hardcoded type definitions here to avoid conflicts +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" + enum class Metric { LATENCY = 0, @@ -199,3 +202,4 @@ void gemm_host_reference(int verify, c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } } +#pragma clang diagnostic pop From 301eb5cf083a03382f7dc69b3277038658a10b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= <153429852+zsotakal@users.noreply.github.com> Date: Mon, 2 Feb 2026 22:58:11 +0100 Subject: [PATCH 2/5] Implement device grouped gemm fixed nk multi abd for rdna4 (#3619) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * device struct implementation * added xdl grouped multi abd fixed nk testing * wmma implementation fixed * avoid unnecessary device mem allocation and code cleanups * cleanup instances definitions * wmma examples added * code cleanups * fix clang format * typo and compilation fixes related to reference gemm * fix compilation error due to std::remove_cvref_t * added missing hip_check_error includes * correction to example instances * review commentes addressed * removed split-k from testing * code formatting --------- Co-authored-by: Zoltán Lakatos Co-authored-by: illsilin_amdeng --- ...grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp | 2 + .../59_grouped_gemm_multi_ABD/CMakeLists.txt | 8 + ...m_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp | 400 ++++++++ ...gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp | 396 ++++++++ ...mm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 27 +- ..._gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 33 +- ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 899 ++++++++++++++++++ ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 27 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 + include/ck/utility/tuple_helper.hpp | 9 + .../cpu/reference_gemm_multi_abd.hpp | 194 ++++ .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp | 295 +++++- .../CMakeLists.txt | 6 +- ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 144 +++ ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 144 +++ ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 144 +++ ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp | 10 +- ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp | 2 + .../profiler/profile_gemm_multi_abd_impl.hpp | 88 +- ...e_grouped_gemm_multi_abd_fixed_nk_impl.hpp | 534 +++++++++++ test/grouped_gemm/CMakeLists.txt | 6 + .../test_grouped_gemm_multi_abd_fixed_nk.cpp | 256 +++++ 24 files changed, 3517 insertions(+), 120 deletions(-) create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp create mode 100644 test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp index 0766373465..e6e2137bea 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -15,6 +15,8 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::hip_check_error; template diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index 4155e0a344..d7ff58705c 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -8,3 +8,11 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) + +add_custom_target(example_grouped_gemm_wmma_multi_abd) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp) +add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8) \ No newline at end of file diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp new file mode 100644 index 0000000000..4eab6cfce2 --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp @@ -0,0 +1,400 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> b_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a0_tensors.reserve(group_count); + b_tensors.reserve(group_count); + b0_tensors.reserve(group_count); + b1_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b1_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); + + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + + sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); + } + } + + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + c_host_tensors[i], + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000..c494e45bfb --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp @@ -0,0 +1,396 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/utility/scheduler_enum.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; +using Scale = ck::tensor_operation::element_wise::Scale; +using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; + +using A0DataType = F16; +using A1DataType = F32; +using AsDataType = ck::Tuple; +using B0DataType = F16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using A1Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); + } + }; + + std::vector> a0_tensors; + std::vector> a1_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a0_tensors.reserve(group_count); + a1_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + a1_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + a1_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 2; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + a1_tensors_device.emplace_back(std::make_unique( + sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + {1, 1}, + {problem_size.stride_Bs[i]}, + {0}, + 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer(), + a1_tensors_device[i]->GetDeviceBuffer()}, + std::array{b_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i], + problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i]}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + constexpr float scale = 1.f; + auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(&argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); + } + } + + c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + e_host_tensors[i], + PassThrough{}, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(64); + problem_size.Ks.push_back(64); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 28b3fa9213..dfb20777bc 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -20,6 +20,8 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -220,8 +222,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back( - std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); b0_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -232,21 +234,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back( - std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); - - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), - a0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A0DataType)); - - b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), - b0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B0DataType)); - - b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), - b1_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B1DataType)); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data()); + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -398,7 +391,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 032842b9eb..82c2e17308 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -20,6 +20,8 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/host_utility/hip_check_error.hpp" + using ::ck::DeviceMem; using ::ck::hip_check_error; using ::ck::HostTensorDescriptor; @@ -47,9 +49,9 @@ using B0DataType = F16; using BsDataType = ck::Tuple; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = F32; +using D0DataType = F16; using DsDataType = ck::Tuple; -using EDataType = F32; +using EDataType = F16; using A0Layout = Row; using A1Layout = Row; @@ -210,11 +212,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co for(int i = 0; i < group_count; i++) { - a0_tensors_device.emplace_back( - std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); - a1_tensors_device.emplace_back( - std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); + a1_tensors_device.emplace_back(std::make_unique( + sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i])); b_tensors_device.emplace_back(std::make_unique( sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); @@ -222,19 +224,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co d0_tensors_device.emplace_back( std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); - c_tensors_device.emplace_back( - std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); - a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), - a0_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A0DataType)); - - a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), - a1_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(A1DataType)); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), - b_tensors[i].mDesc.GetElementSpaceSize() * - sizeof(B0DataType)); + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data()); + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); @@ -394,7 +389,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4: k_batch (>0)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp new file mode 100644 index 0000000000..10e604de60 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -0,0 +1,899 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if defined(__gfx11__) || defined(__gfx12__) + __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>()]; + + const index_t KBatch = 1; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + auto karg = gemm_desc_ptr[group_id]; + + if(karg.M == 0 || karg.N == 0 || karg.K == 0) + return; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) +#endif + { + + typename GridwiseGemm::Problem problem(karg.M, + karg.N, + karg.K, + karg.StrideAs, + karg.StrideBs, + karg.StrideDs, + karg.StrideE, + KBatch); + + const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(karg.p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(karg.p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + static_cast(karg.p_e_grid), + p_shared, + problem, + block_2_etile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } + } +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // Note: Pass multiple layout but then using only the first one + // This is to replicate xdl functionality but it should be extended + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + typename uniform_sequence_gen::type, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + // TODO: Block to tile mappings could potentially moved out to avoid code duplications between + // different device implementations. + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1 + using KernelArgument = typename GridwiseGemm::Argument; + + using GemmTransKernelArg = + GroupedGemmMultiABDKernelArgument; + + static constexpr bool CalculateHasMainKBlockLoop(const GemmTransKernelArg& karg, + index_t k_batch) + { + index_t k_grain = k_batch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / k_batch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + // Client is expected to manually copy the kernel arguments to the device therefore there is + // no point in setting tensor device pointers for the argument structure. + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, + index_t kbatch) + : group_count_{ck::type_convert(gemm_descs.size())}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + grouped_gemm_kernel_args_dev{nullptr}, + gemm_kernel_host_args_{nullptr}, + grid_size_{0}, + k_batch_{kbatch} + { + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t fixed_N = gemm_descs[0].N_; + const index_t fixed_K = gemm_descs[0].K_; + + for(std::size_t g = 0; g < gemm_descs.size(); g++) + { + const index_t M = gemm_descs[g].M_; + const index_t N = gemm_descs[g].N_; + const index_t K = gemm_descs[g].K_; + + if(M != sum_of_m || N != fixed_N || K != fixed_K) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[g].stride_C_; + + if(gemm_descs[g].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; }); + + if(gemm_descs[g].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; }); + + if(gemm_descs[g].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, AverM, N, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + const index_t block_start = grid_size_; + + grid_size_ += grid_size_grp_; + + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = GemmTransKernelArg({p_as_grid, + p_bs_grid, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + gemm_desc_kernel_arg_.emplace_back(std::move(karg)); + + group_id++; + } + } + + void UpdateKBatch(index_t) {} + + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + void* gemm_kernel_host_args_; + index_t grid_size_; + index_t grid_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); + } + + if(arg.k_batch_ != 1) + { + throw std::runtime_error("Split K functionality is not supported for wmma multi " + "abd fixed nk implementation."); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_wmma_fixed_nk; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto Set = InMemoryDataOperationEnum::Set; + ave_time = launch_kernel(integral_constant{}); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return RunImp(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + for(index_t i = 0; i < arg.group_count_; i++) + { + if(CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i], arg.k_batch_) != true) + { + supported = false; + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Wmma_Fixed_Nk" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + // polymorphic + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } + + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index fb4e01b961..36e66017c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK if(arg.grouped_gemm_kernel_args_dev == nullptr) { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr"); } float ave_time = 0; @@ -688,6 +688,11 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { @@ -720,6 +725,26 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } } + for(index_t i = 0; i < arg.group_count_; i++) + { + if(get_warp_size() == 64) + { + if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + else + { + if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + true) + { + supported = false; + } + } + } + return supported; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 7653724b21..311a1c0bf4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK& ty); } +template +auto concat_tuple_of_reference(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + template __host__ __device__ constexpr auto concat_tuple(const Tuple& tx, const Tuple& ty) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp new file mode 100644 index 0000000000..2d766e621b --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp @@ -0,0 +1,194 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/tuple_helper.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmMultiABD : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : as_m_k_{as_m_k}, + bs_k_n_{bs_k_n}, + ds_m_n_{ds_m_n}, + e_m_n_{e_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const AsTensorTuple& as_m_k_; + const BsTensorTuple& bs_k_n_; + const DsTensorTuple& ds_m_n_; + Tensor& e_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultiABD::Argument; + + float Run(const Argument& arg) + { + static constexpr index_t NumATensor = AsTensorTuple::Size(); + static constexpr index_t NumBTensor = BsTensorTuple::Size(); + static constexpr index_t NumDTensor = DsTensorTuple::Size(); + + const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0]; + const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1]; + const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1]; + + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.as_m_k_[Number{}](m, k); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.a_element_op_, data_refs); + } + } + + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.bs_k_n_[Number{}](k, n); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.b_element_op_, data_refs); + } + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = generate_tie( + [&](auto i) -> auto& { return arg.ds_m_n_[Number{}](m, n); }, + Number{}); + auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2); + unpack(arg.cde_element_op_, data_refs); + } + } + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const AsTensorTuple& as_m_k, + const BsTensorTuple& bs_k_n, + const DsTensorTuple& ds_m_n, + Tensor& e_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultiABD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp index 6d97ec3a05..0879bea4ea 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" namespace ck { namespace tensor_operation { @@ -21,6 +20,7 @@ using Multiply = ck::tensor_operation::element_wise::Multiply; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +#if defined(CK_USE_XDL) // RRR void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( std::vector, @@ -179,6 +179,167 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan PassThrough, Multiply, PassThrough>>>& instances); +#endif + +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// RCR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// CRR +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); +#endif // CK_USE // GEMM + Add + Gelu template > op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -246,6 +408,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -289,6 +483,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -317,6 +512,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -360,6 +587,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -388,6 +616,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } @@ -431,6 +691,7 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; +#if defined(CK_USE_XDL) if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -459,6 +720,38 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + op_ptrs); + } + } +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index 9d9a0e691c..fc60f48727 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,13 +1,17 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp + + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..a29f8513d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..2eaaaf009a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3320b4afa6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,144 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = Sequence; + +using BF16 = bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +using Multiply = element_wise::Multiply; +using PassThrough = element_wise::PassThrough; +using AddFastGelu = element_wise::AddFastGelu; +using Add = element_wise::Add; +using FastGelu = element_wise::FastGelu; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| + //######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple, Tuple, DsLayout, Row, Tuple, Tuple, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple, + Tuple, + AddFastGelu, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + Tuple, + Tuple, + Row, + Tuple, + Tuple, + Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple, + Tuple, + Add, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple<>, + Tuple<>, + PassThrough, + GemmMNKPadding>{}); +} + +void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + Tuple, + Tuple<>, + Row, + Tuple, + Tuple, + Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances< + Tuple<>, + Tuple<>, + FastGelu, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp index 23e3b7f511..6e72d379d0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp index 0560f159fc..5eedb8b5ee 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + // DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp index 95365c82e7..7d1fcb5552 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -61,6 +61,8 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that +// portion of the instances are failing. As a workaround these have been commented out. template -auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) -{ - return ck::unpack2( - [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, - tx, - ty); -} - template c_m_n({M, N}); - using AComputeType = typename std::conditional<(NumATensor > 1), EDataType, remove_cvref_t>>::type; - Tensor a_m_k({M, K}); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - // result - auto data_refs1 = ck::tie(a_m_k(m, k)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(a_element_op, data_refs); - } - } - using BComputeType = typename std::conditional<(NumBTensor > 1), EDataType, remove_cvref_t>>::type; - Tensor b_k_n({K, N}); - for(int k = 0; k < K; ++k) - { - for(int n = 0; n < N; ++n) - { - // result - auto data_refs1 = ck::tie(b_k_n(k, n)); - // inputs - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(b_element_op, data_refs); - } - } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + auto ref_argument = ref_gemm.MakeArgument( + as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - // compulsory - auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); - // optional (if multiple Ds) - auto data_refs2 = - generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, - Number{}); - auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); - unpack(cde_element_op, data_refs); - } - } } std::array as_device_buf; diff --git a/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp new file mode 100644 index 0000000000..eea72b324d --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp @@ -0,0 +1,534 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" + +namespace ck { +namespace profiler { + +template +auto reserveVector(std::size_t size) +{ + std::vector vec; + vec.reserve(size); + return vec; +} + +template +bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideE, + const std::vector& kbatch_list = {1}, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + const std::size_t group_count = Ms.size(); + const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0); + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() || + group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n"); + } + + auto generateInputTupleA = [&](std::size_t g) { + if constexpr(NumATensor == 0) + { + static_assert("Gemm problem should have at least 1 A tensor."); + } + else + { + using ALayout = remove_cvref_t{}, AsLayout>>; + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{})); + }, + Number{}); + } + }; + auto generateInputTupleB = [&](std::size_t g) { + if constexpr(NumBTensor == 0) + { + static_assert("Gemm problem should have at least 1 B tensor."); + } + else + { + using BLayout = remove_cvref_t{}, BsLayout>>; + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{})); + }, + Number{}); + } + }; + auto generateInputTupleD = [&](std::size_t g) { + if constexpr(NumDTensor == 0) + { + return ck::Tuple<>(); + } + else + { + using DLayout = remove_cvref_t{}, DsLayout>>; + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return Tensor( + f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{})); + }, + Number{}); + } + }; + + using AsTensorTuple = decltype(generateInputTupleA(0)); + using BsTensorTuple = decltype(generateInputTupleB(0)); + using DsTensorTuple = decltype(generateInputTupleD(0)); + + auto g_as_m_k = reserveVector(group_count); + auto g_bs_k_n = reserveVector(group_count); + auto g_ds_m_n = reserveVector(group_count); + auto g_e_m_n_host_results = reserveVector>(group_count); + auto g_e_m_n_device_results = reserveVector>(group_count); + + for(std::size_t g = 0; g < group_count; g++) + { + auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g)); + auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g)); + auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g)); + + g_e_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); + g_e_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{}))); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << g << std::endl; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; + }); + std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl; + } + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + }); + + break; + default: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + }); + } + } + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector> g_as_device_buf(group_count); + std::vector> g_bs_device_buf(group_count); + std::vector> g_ds_device_buf(group_count); + std::vector g_e_device_buf(group_count); + + std::vector> g_as_device_view(group_count); + std::vector> g_bs_device_view(group_count); + std::vector> g_ds_device_view(group_count); + std::vector g_e_device_view(group_count); + + auto g_gemm_descs = reserveVector(group_count); + + auto grouped_gemm_kernel_args_host = + reserveVector>( + group_count); + + for(std::size_t g = 0; g < group_count; g++) + { + std::array as_stride; + std::array bs_stride; + std::array ds_stride; + + auto& as_m_k = g_as_m_k[g]; + auto& as_device_buf = g_as_device_buf[g]; + auto& as_device_view = g_as_device_view[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_device_buf[i] = std::make_unique(sizeof(ADataType) * Ms[g] * Ks[g]); + as_device_buf[i]->ToDevice(as_m_k[i].mData.data()); + as_device_view[i] = as_device_buf[i]->GetDeviceBuffer(); + as_stride[i] = StrideAs[g]; + }); + + auto& bs_k_n = g_bs_k_n[g]; + auto& bs_device_buf = g_bs_device_buf[g]; + auto& bs_device_view = g_bs_device_view[g]; + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_device_buf[i] = std::make_unique(sizeof(BDataType) * Ks[g] * Ns[g]); + bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data()); + bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer(); + bs_stride[i] = StrideBs[g]; + }); + + auto& ds_m_n = g_ds_m_n[g]; + auto& ds_device_buf = g_ds_device_buf[g]; + auto& ds_device_view = g_ds_device_view[g]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_device_buf[i] = std::make_unique(sizeof(DDataType) * Ms[g] * Ns[g]); + ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data()); + ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer(); + ds_stride[i] = StrideDs[g]; + }); + + g_e_device_buf[g] = std::make_unique(sizeof(EDataType) * Ms[g] * Ns[g]); + g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer(); + + g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{ + sum_of_m, + Ns[g], + Ks[g], + std::vector(as_stride.begin(), as_stride.end()), + std::vector(bs_stride.begin(), bs_stride.end()), + std::vector(ds_stride.begin(), ds_stride.end()), + StrideE[g]}); + + tensor_operation::device:: + GroupedGemmMultiABDKernelArgument + kernelArg{as_device_view, + bs_device_view, + ds_device_view, + g_e_device_view[g], + Ms[g], + Ns[g], + Ks[g], + as_stride, + bs_stride, + ds_stride, + StrideE[g]}; + + grouped_gemm_kernel_args_host.push_back(std::move(kernelArg)); + } + + using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + if(do_verification) + { + using AComputeType = + typename std::conditional<(NumATensor > 1), + EDataType, + remove_cvref_t>>::type; + + using BComputeType = + typename std::conditional<(NumBTensor > 1), + EDataType, + remove_cvref_t>>::type; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultiABD; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + for(std::size_t i = 0; i < group_count; i++) + { + auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i], + g_bs_k_n[i], + g_ds_m_n[i], + g_e_m_n_host_results[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs); + + if(!gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Gemm incompatible with runtime set parameters. Skipping..." + << std::endl; + } + + continue; + } + + DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_host.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + for(const auto kbatch_curr : kbatch_list) + { + gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->SetZero(); + } + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + if(do_verification) + { + bool instance_pass = true; + for(std::size_t g = 0; g < group_count; g++) + { + g_e_device_buf[g]->FromDevice( + g_e_m_n_device_results[g].mData.data(), + g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType)); + + instance_pass = + instance_pass && ck::utils::check_err(g_e_m_n_device_results[g], + g_e_m_n_host_results[g]); + + if(do_log) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",") + << std::endl; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + LogRangeAsType( + std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",") + << std::endl; + }); + LogRangeAsType( + std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t g = 0; g < group_count; g++) + { + flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g]; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + num_btype += sizeof(ADataType) * Ms[g] * Ks[g]; + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + num_btype += sizeof(BDataType) * Ks[g] * Ns[g]; + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + num_btype += sizeof(DDataType) * Ms[g] * Ns[g]; + }); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 450950cbd6..bc79c85e59 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,6 +18,12 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() + + add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp new file mode 100644 index 0000000000..610e7f2b77 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp @@ -0,0 +1,256 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/ck.hpp" +#include "ck/utility/type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp" + +#include "gtest/gtest.h" + +static ck::index_t param_mask = 0xffffff; +static ck::index_t instance_index = -1; + +using FP32 = float; +using FP16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using Add = ck::tensor_operation::element_wise::Add; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, AddFastGelu>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple, BF16, ck::Tuple, ck::Tuple, ck::Tuple, Row, Add>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, PassThrough>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu>, + std::tuple, ck::Tuple, ck::Tuple<>, BF16, ck::Tuple, ck::Tuple, ck::Tuple<>, Row, FastGelu> +>; +// clang-format on + +template +class TestGroupedGemmMultiABDFixedNK : public testing::Test +{ + protected: + using AsDataType = std::tuple_element_t<0, Tuple>; + using BsDataType = std::tuple_element_t<1, Tuple>; + using DsDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using AccDataType = float; + using AsLayout = std::tuple_element_t<4, Tuple>; + using BsLayout = std::tuple_element_t<5, Tuple>; + using DsLayout = std::tuple_element_t<6, Tuple>; + using ELayout = std::tuple_element_t<7, Tuple>; + using AElementOp = PassThrough; + using BElementOp = Multiply; + using CDEElementOp = std::tuple_element_t<8, Tuple>; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // integer value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + static constexpr int n_warmup_ = 0; + static constexpr int n_iter_ = 1; + + std::vector k_batches_ = {1}; + + private: + template + void SetStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if(std::is_same_v) + { + for(const auto c : cols) + { + strides.emplace_back(c); + } + } + else if(std::is_same_v) + { + for(const auto r : rows) + { + strides.emplace_back(r); + } + } + } + + template + void SetTupleStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if constexpr(Layouts::Size() > 0) + { + // As of now multi ABD implementation supports only tensors with matching layouts. + using Layout = ck::remove_cvref_t{}, Layouts>>; + SetStrides(strides, rows, cols); + } + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs = {}, + const std::vector& StrideBs = {}, + const std::vector& StrideDs = {}, + const std::vector& StrideE = {}) + { + std::vector stride_as = StrideAs; + std::vector stride_bs = StrideBs; + std::vector stride_ds = StrideDs; + std::vector stride_e = StrideE; + + if(stride_as.empty()) + { + SetTupleStrides(stride_as, Ms, Ks); + } + if(stride_bs.empty()) + { + SetTupleStrides(stride_bs, Ks, Ns); + } + if(stride_ds.empty()) + { + SetTupleStrides(stride_ds, Ms, Ns); + } + if(stride_e.empty()) + { + SetStrides(stride_e, Ms, Ns); + } + + RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e); + } + + void RunSingle(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideE) + { + bool pass = + ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideE, + k_batches_, + n_warmup_, + n_iter_); + EXPECT_TRUE(pass); + } +}; + +TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes); + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases) +{ + const std::vector Ms{3, 4}; + constexpr int N = 8; + constexpr int K = 64; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases) +{ + const std::vector Ms{3, 5, 16, 7, 8}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) + { + // Run with default arguments. + } + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} From 3e777217551c82a47eb9540791fb5542f2704e63 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 3 Feb 2026 02:41:53 +0400 Subject: [PATCH 3/5] feat: add split_k support for block scale gemm bquant mode. (#3653) * WIP: add splitk to bquant * feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types * chore: remove temporary test script * fix: incorrect tile window length for splitted bq tensor window * chore: improve comments * test: add unit tests to cover bquant splitk functionality * fix: conflict resolution by renaming variables --- .../gemm_bquant_quantgrouped_bf8.cpp | 2 +- .../gemm_bquant_quantgrouped_bf8i4.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8i4.cpp | 2 +- .../run_gemm_quant_example.inc | 183 +----------------- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 130 +++++++++++-- .../kernel/grouped_gemm_quant_kernel.hpp | 8 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 11 ++ .../test_gemm_quant_bquant_splitk_decode.cpp | 61 ++++++ .../test_gemm_quant_bquant_splitk_prefill.cpp | 64 ++++++ .../test_gemm_quant_fixtures.hpp | 16 +- 11 files changed, 273 insertions(+), 208 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index a95c0346cf..1520f2c591 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index d2b95d3263..a93fe15a1b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index a8c13c1b3d..39747ff0bc 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 6576b22c03..ed18cd8890 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 540d5725dd..508f3ac8ec 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - + // Split-K validation is handled by Kernel::IsSupportedArgument + // Split-K is only supported for BQuantGrouped without preshuffle if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else - { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - } - } - else if(init_method == 4) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } - else if(init_method == 5) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); - } - // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) - for(ck_tile::index_t row = 0; - row < static_cast(aq_tensor_ptr->get_length(0)); - ++row) - { - for(ck_tile::index_t col = 0; - col < static_cast(aq_tensor_ptr->get_length(1)); - ++col) - { - (*aq_tensor_ptr)(row, col) = static_cast(col + 1); - } - } - // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v || - std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b49..db86fdbeac 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -380,9 +380,18 @@ struct QuantGemmKernel __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); - const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); + constexpr auto K1 = + GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block + const index_t K_t = amd_wave_read_first_lane( + kargs.k_batch * K1); // amount of K elements consumed if every split-K batch + // performs exactly one "unit" (K1) + const index_t KRead = amd_wave_read_first_lane( + (kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch + // offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4) + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + const index_t b_k_offset_elements = + amd_wave_read_first_lane(k_id * KRead / BPackedSize); if constexpr(std::is_same_v) { @@ -395,11 +404,11 @@ struct QuantGemmKernel if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements); } if(k_id < static_cast(kargs.k_batch - 1)) @@ -410,10 +419,47 @@ struct QuantGemmKernel { splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } + + // Compute BQ offset for BQuantGrouped mode (non-preshuffle only) + // Note: With the alignment validation in IsSupportedArgument, KRead is always + // a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. + if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant) + { + using BQuantGroupSize = remove_cvref_t; + // Compute the K offset for this batch (in terms of K elements) + const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); + // Convert K offset to BQ group offset (logical offset in K/kK dimension) + bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK); + + // BQ tensor layout: + // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] + // ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1] + if constexpr(std::is_same_v) + { + // For RowMajor BQ, K is the row dimension + // offset = bq_group_offset * stride_BQ + const index_t stride_bq = + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN)); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); + } + else if constexpr(std::is_same_v) + { + // For ColumnMajor BQ, K is the column dimension + // offset = bq_group_offset + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); + } + } + else + { + bq_group_offset = 0; + bq_k_split_offset = 0; + } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) index_t splitted_k; }; @@ -805,10 +851,13 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, const QuantGemmKernelArgs& kargs, + const index_t bq_group_offset, const index_t i_m, const index_t i_n) { // Step 1: Create tensor view for BQ + // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset). + // The dimension should use the remaining K-groups from this offset position. const auto& bq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::RowColQuant) { @@ -850,11 +899,12 @@ struct QuantGemmKernel "ABQuantGrouped requires ColumnMajor BQ layout"); } + using BQuantGroupSize = remove_cvref_t; if constexpr(std::is_same_v) { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), + make_tuple(kargs.QK_B - bq_group_offset, integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, @@ -865,8 +915,8 @@ struct QuantGemmKernel return make_naive_tensor_view( bq_ptr, make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), - integer_divide_ceil(kargs.K, BQuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1), + kargs.QK_B - bq_group_offset), + make_tuple(kargs.QK_B, 1), number{}, number<1>{}); } @@ -1047,13 +1097,61 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { + // Split-K is supported for BQuantGrouped mode without preshuffle if(kargs.k_batch != 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + constexpr bool is_bquant_non_preshuffle = + (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; + if constexpr(!is_bquant_non_preshuffle) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 ! " + "Split-K only supported for BQuantGrouped without preshuffle."); + } + return false; + } + else + { + using BQuantGroupSize = remove_cvref_t; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Constraint 1: KRead must align with B packing requirements. + // For packed data types, multiple K elements are stored in each storage unit. + // Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch. + // If KRead is not divisible by BPackedSize, this division produces a fractional + // offset, making it impossible to start reading from a valid storage unit boundary. + if(KRead % BPackedSize != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!"); + } + return false; + } + + // Constraint 2: KRead must align with quantization group boundaries. + // Each split-K batch reads KRead consecutive K elements. If KRead is not + // a multiple of BQuantGroupSize::kK, the batch will span partial quantization + // groups, requiring split access to a quantization scale. This violates the + // atomic processing requirement where each batch must work with complete groups. + if(KRead % BQuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by BQuantGroupSize::kK=" + + std::to_string(BQuantGroupSize::kK)); + } + return false; + } } - return false; } if constexpr(std::is_same_v) @@ -1215,7 +1313,10 @@ struct QuantGemmKernel const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass bq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position. + const auto& bq_block_window = MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -1343,8 +1444,9 @@ struct QuantGemmKernel const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); - const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const BQDataType* bq_ptr = + static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index c9e725f5fd..8b77b01e2f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8e005d588e..2b19053f41 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant split-K tests (no preshuffle) + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode + test_gemm_quant_bquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill + test_gemm_quant_bquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleB) - split into 5 files add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d test_gemm_quant_bquant_preshuffle_decode_1d.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp new file mode 100644 index 0000000000..ea1a8a1fbb --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKDecodeTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + this->run_test_with_validation(32, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + this->run_test_with_validation(32, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + this->run_test_with_validation(32, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4×128 ✓ + // Also K must be divisible by K_Tile(256)*split_k(5)=1280 + this->run_test_with_validation(32, 128, 2560, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp new file mode 100644 index 0000000000..f4f93dbbb6 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKPrefillTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(2)=256 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + // K must be divisible by K_Tile(128)*split_k(3)=384 + this->run_test_with_validation(128, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(4)=512 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=1920 for split_k=5: 1920/5=384=3×128 ✓ + // K must be divisible by K_Tile(128)*split_k(5)=640 + this->run_test_with_validation(128, 128, 1920, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a8..ca21bc69b7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; @@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBasetemplate calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) { From f2b9b3a3a65478ff84b5829c5ce2b2d2ab095905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 3 Feb 2026 03:07:33 +0100 Subject: [PATCH 4/5] Fix path to ck tile conv fwd instance generator (#3699) * Fix path to ck tile conv fwd instance generator * fixes --- Jenkinsfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ca7c4f1d93..80721ea6d3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -581,7 +581,7 @@ def cmake_build(Map conf=[:]){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" } - if ((params.RUN_BUILDER_TESTS || params.RUN_FULL_CONV_TILE_TESTS) && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { setup_args = " -D CK_EXPERIMENTAL_BUILDER=ON " + setup_args } setup_cmd = conf.get( @@ -1428,8 +1428,8 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ python3 ../experimental/builder/src/generate_instances.py --mode=profiler && \ - ../script/cmake-ck-dev.sh ../ gfx90a && \ + execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \ + cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \ make -j64 test_grouped_convnd_fwd_tile && \ ./bin/test_grouped_convnd_fwd_tile""" } From 8b56ffb6aea4dd5e3c531912ee6b2258398606ee Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 2 Feb 2026 18:25:56 -0800 Subject: [PATCH 5/5] Fix one more lifetimebound error. (#3703) * fix staging compiler errors * fix clang format --- include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index 260ebcf4cc..35d987a79a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -63,7 +63,10 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 true> c_thread_buf_; - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + __host__ __device__ constexpr auto& GetCThreadBuffer() [[clang::lifetimebound]] + { + return c_thread_buf_; + } __device__ static auto GetWaveIdx() {