From 0fbb3bb8c4b5860594580dd52ce5245785ceb71a Mon Sep 17 00:00:00 2001 From: andrew clark Date: Wed, 21 Jan 2026 11:00:53 -0700 Subject: [PATCH 01/42] Sanitizing URL-encoded characters from the image file name (#3622) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b57638caa7..f3a597e404 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -115,7 +115,7 @@ def generateAndArchiveBuildTraceVisualization(String buildTraceFileName) { // Run container to get snapshot def dockerOpts = "--cap-add=SYS_ADMIN -v \"\$(pwd)/workspace:/workspace\" -e NODE_PATH=/home/pptruser/node_modules -e BUILD_TRACE_FILE=${buildTraceFileName}" // Create unique image name by sanitizing job name - def sanitizedJobName = env.JOB_NAME.replaceAll(/[\/\\:*?"<>| ]/, '_') + def sanitizedJobName = env.JOB_NAME.replaceAll(/[\/\\:*?"<>| ]/, '_').replaceAll('%2F', '_') def architectureName = (buildTraceFileName =~ /(gfx[0-9a-zA-Z]+)/)[0][1] def imageName = "perfetto_snapshot_${sanitizedJobName}_build_${env.BUILD_NUMBER}_${architectureName}.png" sh """ From 1040d9b1f53945867d78d0bbcf03de65ee01aea3 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Wed, 21 Jan 2026 19:18:47 +0100 Subject: [PATCH 02/42] [CK_BUILDER] Replace reference conv with old ck implementation (#3604) * ck-builder: remove SPATIAL_DIM parameter from ConvTensorLayouts This information is already in the SIGNATURE, so its pointless to pass it separately. This streamlines the interface of those functions a bit. Also touches up the style of those files in general. * ck-builder: implement reference conv using old ck The old ck implementation is more featureful and better tested. * ck-builder: replace test_reference_execution reference with old ck This strips out the ck-tile gpu reference implementation completely. * ck-builder: clean up test_reference_execution - Remove unneccesary messages - Replace EXPECT_TRUE(true) with EXPECT_NO_THROW() --- .../factory/conv_bwd_weight_dl_factory.hpp | 2 +- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 2 +- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 2 +- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 2 +- .../factory/conv_bwd_weight_wmma_factory.hpp | 2 +- .../conv_bwd_weight_wmma_v3_factory.hpp | 2 +- .../factory/conv_bwd_weight_xdl_factory.hpp | 2 +- .../conv_bwd_weight_xdl_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_dl_factory.hpp | 2 +- .../factory/conv_fwd_large_tensor_factory.hpp | 2 +- .../builder/factory/conv_fwd_v3_factory.hpp | 2 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 2 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 2 +- .../builder/factory/conv_tile_factory.hpp | 2 +- .../factory/helpers/ck/conv_tensor_layout.hpp | 47 +- .../ck_tile/conv_tile_tensor_layout.hpp | 49 +- .../builder/factory/reference_common.hpp | 118 --- .../builder/factory/reference_factory.hpp | 269 ++----- .../ck_tile/builder/testing/conv_fwd.hpp | 2 +- .../builder/testing/conv_fwd_reference.hpp | 47 +- .../builder/test/unit_conv_tensor_layout.cpp | 32 +- .../validation/test_reference_execution.cpp | 758 +++--------------- .../test_reference_instance_traits.cpp | 6 - 24 files changed, 291 insertions(+), 1067 deletions(-) delete mode 100644 experimental/builder/include/ck_tile/builder/factory/reference_common.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp index fda1659c75..e8aed8da51 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -23,7 +23,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index b02dea9558..24dcf05f3a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 4f6812617a..5cffdd87f0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index adf108bac4..7a391ab74f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index d887c1c1ce..6a1daf6ef4 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 4067845291..3fa15856fa 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index 027c8a1fba..ab941eb927 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index fbb177f333..46b1ab3965 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index 66a47c5407..11f206483f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 1d55772dd6..03989c9527 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,7 +24,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index b80406c37e..f7c98f244d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 74554df7e9..14266ad63f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index cb36122f7c..652b032a9b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index b3be21f1f3..79bcd84981 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,7 +26,7 @@ template ; + using Layouts = internal::ConvTensorLayouts; using Types = internal::ConvTensorDataTypes; using Ops = internal::ConvElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index 35c87b61ce..b1f9136eed 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -29,7 +29,7 @@ template ; + using Layouts = internal::TileConvTensorLayouts; using Types = internal::TileConvTensorTypes; using Ops = internal::TileElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index fd6de9ae21..760106c1ae 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -172,10 +172,10 @@ struct LayoutToCK using type = ck::tensor_layout::convolution::GNDHWK; }; -template +template consteval auto TensorLayoutToCK() { - return typename LayoutToCK::type{}; + return typename LayoutToCK::type{}; } struct EmptyAuxiliaryTensorLayout @@ -183,49 +183,52 @@ struct EmptyAuxiliaryTensorLayout using type = ck::Tuple<>; }; -template +template consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence) { return ck::Tuple< - decltype(TensorLayoutToCK())...>{}; + decltype(TensorLayoutToCK())...>{}; } -template +template requires(ConvSpatialDim) struct AuxiliaryTensorLayouts { - static constexpr auto Size = AuxiliaryTensorConfigsValue.size(); - using type = decltype(GetAuxiliaryTensorLayoutTuple( + static constexpr auto Size = AUXILIARY_TENSOR_CONFIGS_VALUE.size(); + using type = decltype(GetAuxiliaryTensorLayoutTuple( std::make_index_sequence{})); }; // TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). -template - requires(HasElementwiseOpWithAuxiliaryOperands) +template + requires HasElementwiseOpWithAuxiliaryOperands consteval auto GetAuxiliaryTensorLayouts() { - return AuxiliaryTensorLayouts{}; + return AuxiliaryTensorLayouts{}; } -template - requires(!HasElementwiseOpWithAuxiliaryOperands) +template + requires(!HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return EmptyAuxiliaryTensorLayout{}; } -template - requires(ConvSpatialDim && - ValidConvInputLayoutForSpatialDim && - ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim) +template + requires ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim struct ConvTensorLayouts { - using InLayout = decltype(TensorLayoutToCK()); - using WeiLayout = decltype(TensorLayoutToCK()); - using OutLayout = decltype(TensorLayoutToCK()); - using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); + using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp index 2aaca98586..17615f84cc 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -9,10 +9,10 @@ namespace ck_tile::builder::factory::internal { using ALayout = ck_tile::tensor_layout::convolution::NWGC; -template +template struct LayoutToCKTile { - static_assert(sizeof(UnsupportedEnumValue) == 0, + static_assert(sizeof(UnsupportedEnumValue) == 0, "Unsupported layout conversion to CK."); }; @@ -152,49 +152,52 @@ struct EmptyAuxiliaryTileTensorLayout using type = ck_tile::tuple<>; }; -template +template consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) { return ck_tile::tuple< - decltype(TensorLayoutToCKTile())...>{}; + decltype(TensorLayoutToCKTile())...>{}; } -template - requires(ConvSpatialDim) +template + requires ConvSpatialDim struct AuxiliaryTileTensorLayouts { - static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); - using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + static constexpr auto Size = AUXILIARY_TILE_TENSOR_CONFIGS_VALUE.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( std::make_index_sequence{})); }; // TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). -template - requires(HasElementwiseOpWithAuxiliaryOperands) +template + requires HasElementwiseOpWithAuxiliaryOperands consteval auto GetAuxiliaryTileTensorLayouts() { - return AuxiliaryTileTensorLayouts{}; + return AuxiliaryTileTensorLayouts{}; } -template - requires(!HasElementwiseOpWithAuxiliaryOperands) +template + requires(!HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTileTensorLayouts() { return EmptyAuxiliaryTileTensorLayout{}; } -template - requires(ConvSpatialDim && - ValidConvInputLayoutForSpatialDim && - ValidConvWeightLayoutForSpatialDim && - ValidConvOutputLayoutForSpatialDim) +template + requires ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim struct TileConvTensorLayouts { - using ALayout = decltype(TensorLayoutToCKTile()); - using BLayout = decltype(TensorLayoutToCKTile()); - using ELayout = decltype(TensorLayoutToCKTile()); - using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_common.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_common.hpp deleted file mode 100644 index 698ed43cb9..0000000000 --- a/experimental/builder/include/ck_tile/builder/factory/reference_common.hpp +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/types.hpp" -#include - -namespace ck_tile::builder::factory::internal { - -// Validation helper: Ensure reference implementation only receives PassThrough elementwise ops -template -consteval void ValidateReferenceSignature() -{ - using namespace ck_tile::builder; - - // Check input elementwise operation - static_assert( - !HasTensorOp || - SIGNATURE.input.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH, - "Reference implementation does not support elementwise operations on input tensor. " - "Input operation must be PassThrough (or not specified)."); - - // Check weight elementwise operation - static_assert( - !HasTensorOp || - SIGNATURE.weight.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH, - "Reference implementation does not support elementwise operations on weight tensor. " - "Weight operation must be PassThrough (or not specified)."); - - // Check output elementwise operation - static_assert( - !HasTensorOp || - SIGNATURE.output.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH, - "Reference implementation does not support elementwise operations on output tensor. " - "Output operation must be PassThrough (or not specified)."); -} - -// Common argument structure for reference convolution implementations -// Template parameters allow different const qualifiers for each direction -template -struct ReferenceConvArgument -{ - InPtrType input_; - WeiPtrType weight_; - OutPtrType output_; - int G_, N_, K_, C_; - std::vector input_spatial_; - std::vector filter_spatial_; - std::vector output_spatial_; - std::vector strides_; - std::vector dilations_; - std::vector left_pads_; - - ReferenceConvArgument(InPtrType input, - WeiPtrType weight, - OutPtrType output, - int G, - int N, - int K, - int C, - const std::vector& input_spatial, - const std::vector& filter_spatial, - const std::vector& output_spatial, - const std::vector& strides, - const std::vector& dilations, - const std::vector& left_pads) - : input_(input), - weight_(weight), - output_(output), - G_(G), - N_(N), - K_(K), - C_(C), - input_spatial_(input_spatial), - filter_spatial_(filter_spatial), - output_spatial_(output_spatial), - strides_(strides), - dilations_(dilations), - left_pads_(left_pads) - { - } -}; - -// Common invoker structure for reference convolution implementations -// Takes a callable (lambda or function pointer) to execute the actual convolution -template -struct ReferenceConvInvoker -{ - ConvFunc conv_func_; - - explicit ReferenceConvInvoker(ConvFunc func) : conv_func_(func) {} - - float Run(const ArgumentType* arg, const StreamConfig& stream_config = StreamConfig{}) - { - (void)stream_config; // Unused for reference implementation - - conv_func_(arg->input_, - arg->weight_, - arg->output_, - arg->G_, - arg->N_, - arg->K_, - arg->C_, - arg->input_spatial_, - arg->filter_spatial_, - arg->output_spatial_, - arg->strides_, - arg->dilations_, - arg->left_pads_); - - return 0.0f; // Reference implementation doesn't track timing - } -}; - -} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index f6fc2dbda8..32f3ff7e6e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -3,15 +3,15 @@ #pragma once -#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp" -#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp" -#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp" #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/reference_common.hpp" -#include "ck_tile/core.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" +#include "ck/library/utility/convolution_parameter.hpp" #include namespace ck_tile::builder::factory { @@ -22,16 +22,23 @@ template struct ReferenceFactory { - // Validate that only PassThrough elementwise operations are specified - static constexpr auto kValidation = (internal::ValidateReferenceSignature(), 0); - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Types = internal::ConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; using InDataType = typename Types::InDataType; using WeiDataType = typename Types::WeiDataType; using OutDataType = typename Types::OutDataType; + using Layouts = factory::internal::ConvTensorLayouts; + using InLayout = typename Layouts::InLayout; + using WeiLayout = typename Layouts::WeiLayout; + using OutLayout = typename Layouts::OutLayout; + + using Ops = factory::internal::ConvElementwiseOps; + using InElementwiseOp = typename Ops::InElementwiseOp; + using WeiElementwiseOp = typename Ops::WeiElementwiseOp; + using OutElementwiseOp = typename Ops::OutElementwiseOp; + struct Instance { // Store template parameters for InstanceTraits reflection @@ -39,91 +46,57 @@ struct ReferenceFactory static constexpr auto kAlgorithm = ALGORITHM; static constexpr auto kVersion = VERSION; - // Argument and Invoker types depend on direction - // Forward: const input, const weight, mutable output - // Backward Data: mutable input, const weight, const output_grad - // Backward Weight: const input, mutable weight_grad, const output_grad - - // Use appropriate Argument type based on direction - using Argument = std::conditional_t< - ConvDirectionIsForward, - internal::ReferenceConvArgument, - std::conditional_t< - ConvDirectionIsBackwardData, - internal:: - ReferenceConvArgument, - internal:: - ReferenceConvArgument>>; - - // Invoker calls the appropriate reference implementation based on direction - struct Invoker + /// @brief Invoke reference convolution + /// + /// This is the primary overload to invoke reference convolution. As the underlying + /// function requires it, this function accepts ConvParam directly. + template + static void Run(InPtrType* input, + WeiPtrType* weight, + OutPtrType* output, + const ck::utils::conv::ConvParam& param, + InElementwiseOp in_op = InElementwiseOp{}, + WeiElementwiseOp wei_op = WeiElementwiseOp{}, + OutElementwiseOp out_op = OutElementwiseOp{}) { - float Run(const Argument* arg, const StreamConfig& stream_config = StreamConfig{}) + if constexpr(ConvDirectionIsForward) { - (void)stream_config; // Unused for reference implementation - - if constexpr(ConvDirectionIsForward) - { - ck_tile:: - naive_grouped_conv_fwd( - arg->input_, - arg->weight_, - arg->output_, - arg->G_, - arg->N_, - arg->K_, - arg->C_, - arg->input_spatial_, - arg->filter_spatial_, - arg->output_spatial_, - arg->strides_, - arg->dilations_, - arg->left_pads_); - } - else if constexpr(ConvDirectionIsBackwardData) - { - ck_tile::naive_grouped_conv_bwd_data(arg->input_, - arg->weight_, - arg->output_, - arg->G_, - arg->N_, - arg->K_, - arg->C_, - arg->input_spatial_, - arg->filter_spatial_, - arg->output_spatial_, - arg->strides_, - arg->dilations_, - arg->left_pads_); - } - else if constexpr(ConvDirectionIsBackwardWeight) - { - ck_tile::naive_grouped_conv_bwd_weight(arg->input_, - arg->weight_, - arg->output_, - arg->G_, - arg->N_, - arg->K_, - arg->C_, - arg->input_spatial_, - arg->filter_spatial_, - arg->output_spatial_, - arg->strides_, - arg->dilations_, - arg->left_pads_); - } - - return 0.0f; // Reference implementation doesn't track timing + ck::ref::naive_conv_fwd( + static_cast(input), + static_cast(weight), + static_cast(output), + param, + in_op, + wei_op, + out_op); } - }; + else if constexpr(ConvDirectionIsBackwardData) + { + ck::ref::naive_conv_bwd_data( + static_cast(input), + static_cast(weight), + static_cast(output), + param, + in_op, + wei_op, + out_op); + } + else if constexpr(ConvDirectionIsBackwardWeight) + { + ck::ref::naive_conv_bwd_weight( + static_cast(input), + static_cast(weight), + static_cast(output), + param, + in_op, + wei_op, + out_op); + } + } - // Direct Run method (simpler interface, direction-agnostic) + /// @brief Invoke reference convolution + /// + /// Convenience overload to avoid having to construct ConvParam manually. template static void Run(InPtrType* input, WeiPtrType* weight, @@ -132,68 +105,27 @@ struct ReferenceFactory int N, int K, int C, - const std::vector& input_spatial, - const std::vector& filter_spatial, - const std::vector& output_spatial, - const std::vector& strides, - const std::vector& dilations, - const std::vector& left_pads) + const std::vector& input_spatial, + const std::vector& filter_spatial, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) { - if constexpr(ConvDirectionIsForward) - { - ck_tile::naive_grouped_conv_fwd( - static_cast(input), - static_cast(weight), - static_cast(output), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - } - else if constexpr(ConvDirectionIsBackwardData) - { - ck_tile:: - naive_grouped_conv_bwd_data( - static_cast(input), - static_cast(weight), - static_cast(output), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - } - else if constexpr(ConvDirectionIsBackwardWeight) - { - ck_tile::naive_grouped_conv_bwd_weight( - static_cast(input), - static_cast(weight), - static_cast(output), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - } + Run(input, + weight, + output, + ck::utils::conv::ConvParam(SPATIAL_DIM, + G, + N, + K, + C, + filter_spatial, + input_spatial, + strides, + dilations, + left_pads, + right_pads)); } std::string GetTypeString() const @@ -209,41 +141,6 @@ struct ReferenceFactory return std::string("GPU_Reference_") + dir_str + "_" + std::to_string(SPATIAL_DIM) + "D"; } - - // Old CK interface: Create argument pointer - template - std::unique_ptr - MakeArgumentPointer(InPtrType input, - WeiPtrType weight, - OutPtrType output, - int G, - int N, - int K, - int C, - const std::vector& input_spatial, - const std::vector& filter_spatial, - const std::vector& output_spatial, - const std::vector& strides, - const std::vector& dilations, - const std::vector& left_pads) const - { - return std::make_unique(input, - weight, - output, - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - } - - // Old CK interface: Create invoker pointer - std::unique_ptr MakeInvokerPointer() const { return std::make_unique(); } }; }; diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index dc2963edc2..51edf41cba 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -76,7 +76,7 @@ struct Args using Ops = factory::internal::ConvElementwiseOps; // TODO: We shouldn't need to call into an internal namespace here. - using Layouts = factory::internal::ConvTensorLayouts; + using Layouts = factory::internal::ConvTensorLayouts; ConvTensorLengths lengths; diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp index 6401c6a5d5..ff276f7c9c 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp @@ -32,27 +32,8 @@ concept RefConvInstance = requires(Conv& conv, const void* input, const void* weight, void* output, - int G, - int N, - int K, - int C, - std::vector dims) { - { - conv.Run(input, - weight, - output, - G, - N, - K, - C, - dims, // input_spatial - dims, // filter_spatial - dims, // output_spatial - dims, // strides - dims, // dilations - dims // left_pads - ) - }; + ck::utils::conv::ConvParam param) { + { conv.Run(input, weight, output, param) }; }; /// @brief `run()` specialization for forward convolution and the reference @@ -84,16 +65,6 @@ std::tuple run(RefConvInstance auto& conv, // Just throw for now, but regard these as TODO items that should be resolved // eventually. - // Right pads are not supported right now for some reason. - for(auto right_pad : param.input_right_pads_) - { - if(right_pad != 0) - { - std::cout << "TODO: Support right pad in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - } - if(!args.make_input_descriptor().is_packed()) { std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl; @@ -110,19 +81,7 @@ std::tuple run(RefConvInstance auto& conv, return std::make_tuple(false, 0.0f); } - conv.Run(inputs.input, - inputs.weight, - outputs.output, - param.G_, - param.N_, - param.K_, - param.C_, - param.input_spatial_lengths_, - param.filter_spatial_lengths_, - param.output_spatial_lengths_, - param.conv_filter_strides_, - param.conv_filter_dilations_, - param.input_left_pads_); + conv.Run(inputs.input, inputs.weight, outputs.output, param); return std::make_tuple(true, 0.0f); } diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 0df94d977e..6d82248e08 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -38,7 +38,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -57,7 +57,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -76,7 +76,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -95,7 +95,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) .weight = {.config = {.layout = GKCX}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -114,7 +114,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -133,7 +133,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -152,7 +152,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = GNHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -171,7 +171,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) .weight = {.config = {.layout = GKCYX}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -190,7 +190,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) .weight = {.config = {.layout = GKCZYX}}, .output = {.config = {.layout = NGKDHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -209,7 +209,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = NDHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -228,7 +228,7 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = GNDHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -387,7 +387,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -414,7 +414,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -442,7 +442,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -470,7 +470,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); @@ -497,7 +497,7 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v)); diff --git a/experimental/builder/test/validation/test_reference_execution.cpp b/experimental/builder/test/validation/test_reference_execution.cpp index 29f9acacd3..0aa656ae55 100644 --- a/experimental/builder/test/validation/test_reference_execution.cpp +++ b/experimental/builder/test/validation/test_reference_execution.cpp @@ -4,10 +4,10 @@ #include "ck_tile/builder/conv_builder.hpp" #include "ck_tile/builder/types.hpp" #include "impl/conv_algorithm_types.hpp" -#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp" -#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp" -#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp" #include "utils/ckb_conv_test_configs.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" #include @@ -53,29 +53,25 @@ TEST(ReferenceExecution, Forward_2D_FP16) // Prepare parameters for Run() std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel ref_kernel; - ref_kernel.Run(reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - // If we get here, Run() worked! - std::cout << "✓ Reference Forward kernel executed!" << std::endl; - EXPECT_TRUE(true); + EXPECT_NO_THROW(ref_kernel.Run(reinterpret_cast(in_dev.GetDeviceBuffer()), + reinterpret_cast(wei_dev.GetDeviceBuffer()), + reinterpret_cast(out_dev.GetDeviceBuffer()), + G, + N, + K, + C, + input_spatial, + filter_spatial, + strides, + dilations, + left_pads, + right_pads)); } TEST(ReferenceExecution, BackwardData_2D_FP16) @@ -109,28 +105,26 @@ TEST(ReferenceExecution, BackwardData_2D_FP16) std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel ref_kernel; - ref_kernel.Run(reinterpret_cast(in_grad_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - std::cout << "✓ Reference Backward Data kernel executed!" << std::endl; - EXPECT_TRUE(true); + EXPECT_NO_THROW( + ref_kernel.Run(reinterpret_cast(in_grad_dev.GetDeviceBuffer()), + reinterpret_cast(wei_dev.GetDeviceBuffer()), + reinterpret_cast(out_grad_dev.GetDeviceBuffer()), + G, + N, + K, + C, + input_spatial, + filter_spatial, + strides, + dilations, + left_pads, + right_pads)); } TEST(ReferenceExecution, BackwardWeight_2D_FP16) @@ -164,217 +158,26 @@ TEST(ReferenceExecution, BackwardWeight_2D_FP16) std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel ref_kernel; - ref_kernel.Run(reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_grad_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - std::cout << "✓ Reference Backward Weight kernel executed!" << std::endl; - EXPECT_TRUE(true); -} - -// Test the old CK interface: MakeArgumentPointer + MakeInvokerPointer -TEST(ReferenceExecution, BackwardData_2D_FP16_InvokerInterface) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_DATA, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 4, K = 4, H = 3, W = 3; - - const size_t in_grad_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_grad_size = G * N * K * H * W * sizeof(ck::half_t); - - ck::DeviceMem in_grad_dev(in_grad_size); - ck::DeviceMem wei_dev(wei_size); - ck::DeviceMem out_grad_dev(out_grad_size); - - in_grad_dev.SetZero(); - wei_dev.SetZero(); - out_grad_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel ref_kernel; - - // TEST: Use the old CK interface (MakeArgumentPointer + MakeInvokerPointer) - auto argument_ptr = ref_kernel.MakeArgumentPointer( - reinterpret_cast(in_grad_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = ref_kernel.MakeInvokerPointer(); - - // Run using invoker - float time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - std::cout << "✓ Reference Backward Data kernel executed via Invoker interface!" << std::endl; - std::cout << " (time = " << time << " ms)" << std::endl; - EXPECT_TRUE(true); -} - -// Test the old CK interface for Forward convolution -TEST(ReferenceExecution, Forward_2D_FP16_InvokerInterface) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 4, K = 4, H = 3, W = 3; - - const size_t in_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_size = G * N * K * H * W * sizeof(ck::half_t); - - ck::DeviceMem in_dev(in_size); - ck::DeviceMem wei_dev(wei_size); - ck::DeviceMem out_dev(out_size); - - in_dev.SetZero(); - wei_dev.SetZero(); - out_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel ref_kernel; - - // TEST: Use the old CK interface (MakeArgumentPointer + MakeInvokerPointer) - auto argument_ptr = ref_kernel.MakeArgumentPointer( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = ref_kernel.MakeInvokerPointer(); - - // Run using invoker - float time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - std::cout << "✓ Reference Forward kernel executed via Invoker interface!" << std::endl; - std::cout << " (time = " << time << " ms)" << std::endl; - EXPECT_TRUE(true); -} - -// Test the old CK interface for Backward Weight convolution -TEST(ReferenceExecution, BackwardWeight_2D_FP16_InvokerInterface) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::GNHWC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::GNHWK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 4, K = 4, H = 3, W = 3; - - const size_t in_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_grad_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_grad_size = G * N * K * H * W * sizeof(ck::half_t); - - ck::DeviceMem in_dev(in_size); - ck::DeviceMem wei_grad_dev(wei_grad_size); - ck::DeviceMem out_grad_dev(out_grad_size); - - in_dev.SetZero(); - wei_grad_dev.SetZero(); - out_grad_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel ref_kernel; - - // TEST: Use the old CK interface (MakeArgumentPointer + MakeInvokerPointer) - auto argument_ptr = ref_kernel.MakeArgumentPointer( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_grad_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = ref_kernel.MakeInvokerPointer(); - - // Run using invoker - float time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - std::cout << "✓ Reference Backward Weight kernel executed via Invoker interface!" << std::endl; - std::cout << " (time = " << time << " ms)" << std::endl; - EXPECT_TRUE(true); + EXPECT_NO_THROW( + ref_kernel.Run(reinterpret_cast(in_dev.GetDeviceBuffer()), + reinterpret_cast(wei_grad_dev.GetDeviceBuffer()), + reinterpret_cast(out_grad_dev.GetDeviceBuffer()), + G, + N, + K, + C, + input_spatial, + filter_spatial, + strides, + dilations, + left_pads, + right_pads)); } // Test Builder Reference vs Direct GPU Reference with RANDOM INPUT @@ -430,10 +233,10 @@ TEST(ReferenceExecution, Forward_2D_FP16_Builder_vs_DirectGPUReference_Random) std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel builder_kernel; @@ -447,26 +250,35 @@ TEST(ReferenceExecution, Forward_2D_FP16_Builder_vs_DirectGPUReference_Random) C, input_spatial, filter_spatial, - output_spatial, strides, dilations, - left_pads); + left_pads, + right_pads); // Run 2: Direct GPU Reference (same kernel the Builder calls internally!) - ck_tile::naive_grouped_conv_fwd<2, ck::half_t, ck::half_t, ck::half_t>( + ck::ref::naive_conv_fwd( reinterpret_cast(in_dev.GetDeviceBuffer()), reinterpret_cast(wei_dev.GetDeviceBuffer()), reinterpret_cast(out_naive_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + ck::utils::conv::ConvParam(2, + G, + N, + K, + C, + filter_spatial, + input_spatial, + strides, + dilations, + left_pads, + right_pads)); // Copy results back std::vector out_builder_result(out_elements); @@ -475,17 +287,11 @@ TEST(ReferenceExecution, Forward_2D_FP16_Builder_vs_DirectGPUReference_Random) out_naive_dev.FromDevice(out_naive_result.data()); // Compare - should be IDENTICAL (both call same kernel) - bool pass = ck::utils::check_err(out_builder_result, + EXPECT_TRUE(ck::utils::check_err(out_builder_result, out_naive_result, "Error: Builder Reference != Direct GPU Reference", 1e-6, - 1e-6); // Very tight tolerance! - - std::cout << "✓ Builder Reference vs Direct GPU Reference (RANDOM INPUT)!" << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - std::cout << " This validates Builder Reference Factory is correct!" << std::endl; - - EXPECT_TRUE(pass); + 1e-6)); // Very tight tolerance! } // Test Builder Reference vs Direct GPU Reference with RANDOM INPUT - Backward Data @@ -538,10 +344,10 @@ TEST(ReferenceExecution, BackwardData_2D_FP16_Builder_vs_DirectGPUReference_Rand std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel builder_kernel; @@ -555,26 +361,35 @@ TEST(ReferenceExecution, BackwardData_2D_FP16_Builder_vs_DirectGPUReference_Rand C, input_spatial, filter_spatial, - output_spatial, strides, dilations, - left_pads); + left_pads, + right_pads); // Run 2: Direct GPU Reference - ck_tile::naive_grouped_conv_bwd_data<2, ck::half_t, ck::half_t, ck::half_t>( + ck::ref::naive_conv_bwd_data( reinterpret_cast(in_grad_naive_dev.GetDeviceBuffer()), reinterpret_cast(wei_dev.GetDeviceBuffer()), reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + ck::utils::conv::ConvParam(2, + G, + N, + K, + C, + filter_spatial, + input_spatial, + strides, + dilations, + left_pads, + right_pads)); // Compare std::vector in_grad_builder_result(in_grad_elements); @@ -582,16 +397,11 @@ TEST(ReferenceExecution, BackwardData_2D_FP16_Builder_vs_DirectGPUReference_Rand in_grad_builder_dev.FromDevice(in_grad_builder_result.data()); in_grad_naive_dev.FromDevice(in_grad_naive_result.data()); - bool pass = ck::utils::check_err(in_grad_builder_result, + EXPECT_TRUE(ck::utils::check_err(in_grad_builder_result, in_grad_naive_result, "Error: Builder Backward Data != Direct GPU Reference", 1e-6, - 1e-6); - - std::cout << "✓ Builder Reference vs Direct GPU Reference (RANDOM INPUT - Backward Data)!" - << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - EXPECT_TRUE(pass); + 1e-6)); } // Test Builder Reference vs Direct GPU Reference with RANDOM INPUT - Backward Weight @@ -644,10 +454,10 @@ TEST(ReferenceExecution, BackwardWeight_2D_FP16_Builder_vs_DirectGPUReference_Ra std::vector input_spatial{H, W}; std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; std::vector strides{1, 1}; std::vector dilations{1, 1}; std::vector left_pads{1, 1}; + std::vector right_pads{1, 1}; RefKernel builder_kernel; @@ -661,26 +471,35 @@ TEST(ReferenceExecution, BackwardWeight_2D_FP16_Builder_vs_DirectGPUReference_Ra C, input_spatial, filter_spatial, - output_spatial, strides, dilations, - left_pads); + left_pads, + right_pads); // Run 2: Direct GPU Reference - ck_tile::naive_grouped_conv_bwd_weight<2, ck::half_t, ck::half_t, ck::half_t>( + ck::ref::naive_conv_bwd_weight( reinterpret_cast(in_dev.GetDeviceBuffer()), reinterpret_cast(wei_grad_naive_dev.GetDeviceBuffer()), reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + ck::utils::conv::ConvParam(2, + G, + N, + K, + C, + filter_spatial, + input_spatial, + strides, + dilations, + left_pads, + right_pads)); // Compare std::vector wei_grad_builder_result(wei_grad_elements); @@ -688,344 +507,11 @@ TEST(ReferenceExecution, BackwardWeight_2D_FP16_Builder_vs_DirectGPUReference_Ra wei_grad_builder_dev.FromDevice(wei_grad_builder_result.data()); wei_grad_naive_dev.FromDevice(wei_grad_naive_result.data()); - bool pass = ck::utils::check_err(wei_grad_builder_result, + EXPECT_TRUE(ck::utils::check_err(wei_grad_builder_result, wei_grad_naive_result, "Error: Builder Backward Weight != Direct GPU Reference", 1e-6, - 1e-6); - - std::cout << "✓ Builder Reference vs Direct GPU Reference (RANDOM INPUT - Backward Weight)!" - << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - EXPECT_TRUE(pass); -} - -// Test Invoker Interface vs Direct GPU Reference with RANDOM INPUT - Forward -TEST(ReferenceExecution, Forward_2D_FP16_InvokerInterface_vs_DirectGPUReference_Random) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::FORWARD, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 16, K = 16, H = 14, W = 14; - - const size_t in_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_size = G * N * K * H * W * sizeof(ck::half_t); - - const size_t in_elements = G * N * C * H * W; - const size_t wei_elements = G * K * C * 3 * 3; - const size_t out_elements = G * N * K * H * W; - - std::vector in_host(in_elements); - std::vector wei_host(wei_elements); - - std::srand(12348); - for(size_t i = 0; i < in_elements; i++) - { - in_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - for(size_t i = 0; i < wei_elements; i++) - { - wei_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - - ck::DeviceMem in_dev(in_size); - ck::DeviceMem wei_dev(wei_size); - ck::DeviceMem out_invoker_dev(out_size); - ck::DeviceMem out_naive_dev(out_size); - - in_dev.ToDevice(in_host.data()); - wei_dev.ToDevice(wei_host.data()); - out_invoker_dev.SetZero(); - out_naive_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel builder_kernel; - - // Run 1: Builder Invoker Interface - auto argument_ptr = builder_kernel.MakeArgumentPointer( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_invoker_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = builder_kernel.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - // Run 2: Direct GPU Reference - ck_tile::naive_grouped_conv_fwd<2, ck::half_t, ck::half_t, ck::half_t>( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_naive_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - // Compare - std::vector out_invoker_result(out_elements); - std::vector out_naive_result(out_elements); - out_invoker_dev.FromDevice(out_invoker_result.data()); - out_naive_dev.FromDevice(out_naive_result.data()); - - bool pass = ck::utils::check_err(out_invoker_result, - out_naive_result, - "Error: Invoker Interface != Direct GPU Reference", - 1e-6, - 1e-6); - - std::cout << "✓ Invoker Interface vs Direct GPU Reference (RANDOM - Forward)!" << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - EXPECT_TRUE(pass); -} - -// Test Invoker Interface vs Direct GPU Reference with RANDOM INPUT - Backward Data -TEST(ReferenceExecution, BackwardData_2D_FP16_InvokerInterface_vs_DirectGPUReference_Random) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_DATA, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 16, K = 16, H = 14, W = 14; - - const size_t in_grad_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_grad_size = G * N * K * H * W * sizeof(ck::half_t); - - const size_t in_grad_elements = G * N * C * H * W; - const size_t wei_elements = G * K * C * 3 * 3; - const size_t out_grad_elements = G * N * K * H * W; - - std::vector wei_host(wei_elements); - std::vector out_grad_host(out_grad_elements); - - std::srand(12349); - for(size_t i = 0; i < wei_elements; i++) - { - wei_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - for(size_t i = 0; i < out_grad_elements; i++) - { - out_grad_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - - ck::DeviceMem in_grad_invoker_dev(in_grad_size); - ck::DeviceMem in_grad_naive_dev(in_grad_size); - ck::DeviceMem wei_dev(wei_size); - ck::DeviceMem out_grad_dev(out_grad_size); - - wei_dev.ToDevice(wei_host.data()); - out_grad_dev.ToDevice(out_grad_host.data()); - in_grad_invoker_dev.SetZero(); - in_grad_naive_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel builder_kernel; - - // Run 1: Builder Invoker Interface - auto argument_ptr = builder_kernel.MakeArgumentPointer( - reinterpret_cast(in_grad_invoker_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = builder_kernel.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - // Run 2: Direct GPU Reference - ck_tile::naive_grouped_conv_bwd_data<2, ck::half_t, ck::half_t, ck::half_t>( - reinterpret_cast(in_grad_naive_dev.GetDeviceBuffer()), - reinterpret_cast(wei_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - // Compare - std::vector in_grad_invoker_result(in_grad_elements); - std::vector in_grad_naive_result(in_grad_elements); - in_grad_invoker_dev.FromDevice(in_grad_invoker_result.data()); - in_grad_naive_dev.FromDevice(in_grad_naive_result.data()); - - bool pass = - ck::utils::check_err(in_grad_invoker_result, - in_grad_naive_result, - "Error: Invoker Interface != Direct GPU Reference (Backward Data)", - 1e-6, - 1e-6); - - std::cout << "✓ Invoker Interface vs Direct GPU Reference (RANDOM - Backward Data)!" - << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - EXPECT_TRUE(pass); -} - -// Test Invoker Interface vs Direct GPU Reference with RANDOM INPUT - Backward Weight -TEST(ReferenceExecution, BackwardWeight_2D_FP16_InvokerInterface_vs_DirectGPUReference_Random) -{ - constexpr ConvSignature sig{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto ref_alg = ConvAlgorithm_Reference{}; - using RefKernel = ConvBuilder::Instance; - - const int G = 1, N = 2, C = 16, K = 16, H = 14, W = 14; - - const size_t in_size = G * N * C * H * W * sizeof(ck::half_t); - const size_t wei_grad_size = G * K * C * 3 * 3 * sizeof(ck::half_t); - const size_t out_grad_size = G * N * K * H * W * sizeof(ck::half_t); - - const size_t in_elements = G * N * C * H * W; - const size_t wei_grad_elements = G * K * C * 3 * 3; - const size_t out_grad_elements = G * N * K * H * W; - - std::vector in_host(in_elements); - std::vector out_grad_host(out_grad_elements); - - std::srand(12350); - for(size_t i = 0; i < in_elements; i++) - { - in_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - for(size_t i = 0; i < out_grad_elements; i++) - { - out_grad_host[i] = ck::half_t(static_cast(std::rand()) / RAND_MAX * 2.0f - 1.0f); - } - - ck::DeviceMem in_dev(in_size); - ck::DeviceMem wei_grad_invoker_dev(wei_grad_size); - ck::DeviceMem wei_grad_naive_dev(wei_grad_size); - ck::DeviceMem out_grad_dev(out_grad_size); - - in_dev.ToDevice(in_host.data()); - out_grad_dev.ToDevice(out_grad_host.data()); - wei_grad_invoker_dev.SetZero(); - wei_grad_naive_dev.SetZero(); - - std::vector input_spatial{H, W}; - std::vector filter_spatial{3, 3}; - std::vector output_spatial{H, W}; - std::vector strides{1, 1}; - std::vector dilations{1, 1}; - std::vector left_pads{1, 1}; - - RefKernel builder_kernel; - - // Run 1: Builder Invoker Interface - auto argument_ptr = builder_kernel.MakeArgumentPointer( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_grad_invoker_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - auto invoker_ptr = builder_kernel.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - // Run 2: Direct GPU Reference - ck_tile::naive_grouped_conv_bwd_weight<2, ck::half_t, ck::half_t, ck::half_t>( - reinterpret_cast(in_dev.GetDeviceBuffer()), - reinterpret_cast(wei_grad_naive_dev.GetDeviceBuffer()), - reinterpret_cast(out_grad_dev.GetDeviceBuffer()), - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); - - // Compare - std::vector wei_grad_invoker_result(wei_grad_elements); - std::vector wei_grad_naive_result(wei_grad_elements); - wei_grad_invoker_dev.FromDevice(wei_grad_invoker_result.data()); - wei_grad_naive_dev.FromDevice(wei_grad_naive_result.data()); - - bool pass = - ck::utils::check_err(wei_grad_invoker_result, - wei_grad_naive_result, - "Error: Invoker Interface != Direct GPU Reference (Backward Weight)", - 1e-6, - 1e-6); - - std::cout << "✓ Invoker Interface vs Direct GPU Reference (RANDOM - Backward Weight)!" - << std::endl; - std::cout << " Result: " << (pass ? "IDENTICAL ✓" : "MISMATCH ✗") << std::endl; - EXPECT_TRUE(pass); + 1e-6)); } } // namespace diff --git a/experimental/builder/test/validation/test_reference_instance_traits.cpp b/experimental/builder/test/validation/test_reference_instance_traits.cpp index 3e79d51ac7..154a0693e4 100644 --- a/experimental/builder/test/validation/test_reference_instance_traits.cpp +++ b/experimental/builder/test/validation/test_reference_instance_traits.cpp @@ -62,8 +62,6 @@ TEST(ReferenceInstanceTraits, Forward_2D_FP16) // Verify instance_string() - now includes data type and layouts! std::string instance_str = Traits::instance_string(); EXPECT_EQ(instance_str, "GPU_Reference_Forward_2D_fp16_NHWGC_GKYXC_NHWGK"); - - std::cout << "✓ Forward InstanceTraits validated: " << instance_str << std::endl; } TEST(ReferenceInstanceTraits, BackwardData_2D_FP16) @@ -86,8 +84,6 @@ TEST(ReferenceInstanceTraits, BackwardData_2D_FP16) std::string instance_str = Traits::instance_string(); EXPECT_EQ(instance_str, "GPU_Reference_BackwardData_2D_fp16_NHWGC_GKYXC_NHWGK"); - - std::cout << "✓ Backward Data InstanceTraits validated: " << instance_str << std::endl; } TEST(ReferenceInstanceTraits, BackwardWeight_2D_FP16) @@ -110,8 +106,6 @@ TEST(ReferenceInstanceTraits, BackwardWeight_2D_FP16) std::string instance_str = Traits::instance_string(); EXPECT_EQ(instance_str, "GPU_Reference_BackwardWeight_2D_fp16_NHWGC_GKYXC_NHWGK"); - - std::cout << "✓ Backward Weight InstanceTraits validated: " << instance_str << std::endl; } } // namespace From 4c2c18ef486641d1493f3dc272a1e0e079676308 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Kulikowski?= Date: Thu, 22 Jan 2026 03:10:16 +0100 Subject: [PATCH 03/42] [CK][Examples] Extending support for rdna3/4 part 4: (#3264) * [CK][Examples] Extending support for rdna3/4 part 4: -example_gemm_xdl_streamk -example_gemm_xdl_fp16_fp8_v3 -example_gemm_xdl_fp16_v3 Signed-off-by: Michal Kulikowski * [CK][Examples] Revert example\01_gemm\gemm_xdl_streamk parameters change. Signed-off-by: Michal Kulikowski --------- Signed-off-by: Michal Kulikowski Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp | 4 ++-- example/01_gemm/gemm_xdl_fp16_v3.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp index 84ea93ad43..d93e7c9177 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -28,10 +28,10 @@ using DeviceGemmV2Instance = ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 64, - 16, 16, + 32, 32, 256, 8, 16, 16, 16, - 1, 1, + 2, 2, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp index e696daf8f0..99691064e2 100644 --- a/example/01_gemm/gemm_xdl_fp16_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -28,10 +28,10 @@ using DeviceGemmV2Instance = ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, - 16, 16, + 32, 32, 256, 8, 8, 16, 16, - 1, 1, + 2, 2, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, From dd0b4294afcf188f4a9154b7eea19f8e786c9539 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 22 Jan 2026 12:58:26 +0800 Subject: [PATCH 04/42] Fp8 block scale quantization for fmha fwd (#3330) * add block scale parameters to kernel * add block scale to kernel * add smoke test * format * Revert "format" This reverts commit 356c3c970664af68a04e2694da7e270b8c8338bf. * only format my code * format py * fix auto not allowd in function prototype * change instance tttt to ttff * fix structured binding issue * change s_acc elementwise op * async pipeline add block scale * add quantation P using shift exp2 * precompute (m - shift) once per row * change blk scale seqstrt ptr name * fix some name * fix for deduction guide * fix some comments * add P scale to qr_ksvs_pipeline * add comment to idx_identity * change the method of calculating descale block index * unify naming style: use block_scale_ as name prefix * unify naming style * update the CHANGELOG.md * Add FP8 block scale quantization support for FMHA forward kernel --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Po Yen Chen --- CHANGELOG.md | 1 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 7 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 26 ++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 230 +++++++++++++---- example/ck_tile/01_fmha/quant.hpp | 7 + .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 5 +- include/ck_tile/core/numeric/math.hpp | 7 + include/ck_tile/core/utility/functional.hpp | 12 + .../host/reference/reference_batched_gemm.hpp | 40 +++ .../block_attention_quant_scale_enum.hpp | 6 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 242 +++++++++++++++++- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 83 +++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 83 +++++- 14 files changed, 667 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a257e464..dfb50e9bdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. +* Added FP8 block scale quantization for FMHA forward kernel. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index a3cfe2622a..cac6671ca5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str: QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..ed86f57232 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1018,7 +1018,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( ["t", "f"], - ["no", "pertensor"], + ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], ["f", "t"], @@ -1146,7 +1146,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], + ["no", "pertensor", "blockscale"], + get_mask_map(mask_impl).keys(), + ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fdd720fd75..aedbb0e17c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) + const void* block_scale_seqstart_q_ptr; + const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -257,6 +259,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -264,6 +269,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -276,6 +284,9 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, + args.block_scale_seqstart_q_ptr, + args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); @@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0c988b2acc..b6287245a0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the + // compute block size + constexpr ck_tile::index_t block_scale_size_q_ = 128; + constexpr ck_tile::index_t block_scale_size_kv_ = 128; + const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); + size_t i_block_scale_q = 0; + size_t i_block_scale_k = 0; + std::vector block_scale_seqstart_q_host = {0}; + std::vector block_scale_seqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); + const ck_tile::index_t num_block_scale_q = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) + : i_block_scale_q; + const ck_tile::index_t num_block_scale_kv = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) + : i_block_scale_k; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor sink_host({nhead}); @@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor q_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_q} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); + ck_tile::HostTensor v_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } + else if(qscale.type == quant_scale_enum::blockscale) + { + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * + sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); + block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); + block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + if(qscale.type == quant_scale_enum::blockscale) + { + args.q_descale_ptr = + reinterpret_cast(q_descale_buf.GetDeviceBuffer()); + args.k_descale_ptr = + reinterpret_cast(k_descale_buf.GetDeviceBuffer()); + args.v_descale_ptr = + reinterpret_cast(v_descale_buf.GetDeviceBuffer()); + + args.block_scale_seqstart_q_ptr = + (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() + : nullptr); + args.block_scale_seqstart_k_ptr = + (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() + : nullptr); + + args.nhead_stride_q_descale = nhead_stride_q_descale; + args.nhead_stride_k_descale = nhead_stride_k_descale; + args.nhead_stride_v_descale = nhead_stride_v_descale; + + args.batch_stride_q_descale = batch_stride_q_descale; + args.batch_stride_k_descale = batch_stride_k_descale; + args.batch_stride_v_descale = batch_stride_v_descale; + + args.block_scale_size_q = block_scale_size_q_; + args.block_scale_size_kv = block_scale_size_kv_; + } + else + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - ck_tile:: - reference_batched_gemm( + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + ck_tile::reference_batched_quant_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + std::get<1>(idx) / block_scale_size_q_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + std::get<2>(idx) / block_scale_size_kv_); + }); + } + else + { + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } if(0.f < logits_soft_cap) { @@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t v_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + ck_tile:: + reference_batched_quant_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return ck_tile::type_convert(value) * + v_descale_host(b_idx, + std::get<0>(idx) / nr, + v_offset + + std::get<2>(idx) / block_scale_size_kv_); + }, + ck_tile::idx_identity{}); + } + else + { + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale ? "no_scale" - : "pertensor", - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale + ? "no_scale" + : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 59d4ac1707..feb28cba24 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,6 +13,7 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, + blockscale, }; struct quant_scale_info @@ -25,6 +26,8 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; + else if(type == quant_scale_enum::blockscale) + os << "bs"; } static quant_scale_info decode(std::string str) @@ -38,6 +41,10 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } + else if(str == "bs" || str == "2") + { + info.type = quant_scale_enum::blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 596542eb9d..227f26c8f3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -95,10 +95,11 @@ run_fp8bf16_tests() { for perm in 0 1 ; do for b in 1 2 ; do for hdim in 64 128 256 ; do + for scale in 1 2; do - $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS - done ; done ; done + done ; done ; done ; done } run_fp8fp32_tests() { diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 96e76f669d..a46ae509dd 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -37,6 +37,13 @@ struct scales return lhs_ * rhs; } + template + CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const + { + auto new_scale = lhs_ * other; + return scales>(new_scale); + } + private: Scale lhs_; }; diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 898d21574e..aa4bfa3f15 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -119,6 +119,18 @@ struct identity } }; +// Similar to identity, but takes an additional index parameter as the first argument. +// The index is ignored and only the second argument (value) is forwarded. +// Useful for indexed element-wise operations where the functor signature requires an index. +struct idx_identity +{ + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept + { + return std::forward(arg); + } +}; + namespace detail { // RemainLengths: sequence<...> diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 63f13b1b16..d742426740 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -47,4 +47,44 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + AccDataType v_a = ck_tile::type_convert( + a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k))); + AccDataType v_b = ck_tile::type_convert( + b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k))); + + v_acc += v_a * v_b; + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert( + acc_element_op(std::make_tuple(batch, m, n), v_acc)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 3755a2bc71..7e0f704bef 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, + BLOCKSCALE, }; template @@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "blockscale"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index adbedc5259..0039c57cfc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -168,6 +168,29 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; + struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs + { + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; + }; + + struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + + struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + const int32_t* block_scale_seqstart_q_ptr; + const int32_t* block_scale_seqstart_k_ptr; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -243,9 +266,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t> { @@ -269,9 +295,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -328,6 +357,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -335,6 +367,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -343,6 +378,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -413,6 +450,23 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -478,6 +532,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -485,6 +542,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -492,6 +552,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -528,6 +590,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -535,6 +600,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -542,6 +610,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -581,6 +651,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -588,6 +661,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -595,6 +671,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -631,6 +709,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -638,6 +719,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -645,6 +729,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -666,6 +752,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -685,6 +773,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -694,6 +785,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -763,6 +856,24 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + + kargs.block_scale_seqstart_q_ptr = + reinterpret_cast(block_scale_seqstart_q_ptr); + kargs.block_scale_seqstart_k_ptr = + reinterpret_cast(block_scale_seqstart_k_ptr); + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -814,6 +925,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -833,6 +946,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -841,6 +957,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -860,6 +978,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -879,6 +999,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -887,6 +1010,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -909,6 +1034,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -928,6 +1055,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -936,6 +1066,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -955,6 +1087,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -974,6 +1108,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -982,6 +1119,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1111,13 +1250,16 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + long_index_t batch_offset_q_descale = 0; + long_index_t batch_offset_k_descale = 0; + long_index_t batch_offset_v_descale = 0; const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s @@ -1153,6 +1295,14 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; + batch_offset_q_descale = bquery_start; + batch_offset_k_descale = bkey_start; + batch_offset_v_descale = bkey_start; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1220,6 +1370,15 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + batch_offset_q_descale = + static_cast(i_batch) * kargs.batch_stride_q_descale; + batch_offset_k_descale = + static_cast(i_batch) * kargs.batch_stride_k_descale; + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1540,7 +1699,8 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { + + auto o_acc_tile = [&, i_nhead_ = i_nhead]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1581,8 +1741,62 @@ struct FmhaFwdKernel block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_value); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + size_t idx = i_m0 / kargs.block_scale_size_q; + float q_descale = q_descale_ptr[idx]; + // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 + // Both P and rowsum are scaled by 2^shift, canceling in normalization + // No additional scaling needed in p_compute_element_func or o_acc_element_func + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + scales(q_descale), // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_size_kv, + sink_value); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index dcccdf541c..2fbc9fdb54 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,8 +57,13 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -167,6 +172,9 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -358,6 +366,13 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -427,11 +442,20 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -449,7 +473,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -466,7 +490,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -571,7 +595,21 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -579,13 +617,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -676,18 +714,39 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } + move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -722,11 +781,16 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(++i_total_loops < num_total_loop); // store lse @@ -846,6 +910,9 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7224ed3a70..046a2f0b9e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -64,6 +65,10 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -190,6 +195,9 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -403,6 +411,13 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -449,11 +464,20 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -471,7 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -488,7 +512,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -630,7 +654,21 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -638,13 +676,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -735,7 +773,27 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -745,7 +803,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -808,13 +866,19 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(i_total_loops < num_total_loop); // store lse @@ -922,6 +986,9 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } }; From 0b13697a88e77a733d36b14353df1c0a7ae756df Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 22 Jan 2026 16:07:14 +0800 Subject: [PATCH 05/42] [CK_TILE][FMHA]Add new tile size for async (#3623) * Revert "Revert "[CK_TILE][FMHA] Add new tile size for async (#3586)" (#3613)" This reverts commit 8f75869408210cb85e9eb7ff639c4c9dad1331cb. * Add new tile_size for async pipeline Signed-off-by: Linjun-AMD * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Signed-off-by: Linjun-AMD Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 +++++++- .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ed86f57232..b59f442663 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ class FmhaFwdApiTrait: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0: + if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) and kernel_ctx.tile.F_bm0 != 128 ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.pipeline.tag != "qr_async" + and kernel_ctx.tile.F_bk0 == 64 + ) ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 @@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 046a2f0b9e..81bd8d5ab5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -329,6 +329,8 @@ struct BlockFmhaPipelineQRKSVSAsync { if(num_total_loop <= 0) { + buffer_load_fence(0); // rocm-7.1.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) if constexpr(kStoreLSE) { auto lse = @@ -345,10 +347,8 @@ struct BlockFmhaPipelineQRKSVSAsync store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } - buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) - // otherwise will have compute error(maybe compiler bug?) - // Note: here occ are all cleard, return it + // Note: here occ are all cleared, return it return o_acc; } __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check From 8daf6ea3026aebe3481792c03026692631059725 Mon Sep 17 00:00:00 2001 From: ApoorvaKalyani Date: Thu, 22 Jan 2026 09:53:59 +0100 Subject: [PATCH 06/42] Grouped conv_fwd_bias_bnorm_clamp instances and tests (#3525) * Added bias_bnorm_clamp instances. * fwd_bias_bnorm_clamp comp instances * fwd_bias_bnorm_mem_inter and mem_intra instances * fwd_bias_bnorm_merged_group_instances * fwd_bias_bnorm_clamp_conv3d_bf16 and f16 instances * Device level changes for fwd_bias_bnorm_clamp * Added the test to the regression test list. * Removed the part 2 and 2x instances * Removed the irrelevant checks in wmma * Refactored the instances to adapt to new device implementation * Updated the reference and include files * enabling tests * Added missing profiler * Added missing instance entry , deleted by mistake * Reduce bias bnorm clamp instances to only a single generic one. * Clean up cmakelists file * clang-format * Change bias bnorm clamp tests to use monotone initialization values to avoid tiny off-integer gemm results on RDNA3 from blowing up. * Renaming some instance lists and add functions to be more standardized. * Commented out non default instances. --------- Co-authored-by: kiefer --- ...uped_conv_fwd_wmma_cshufflev3_instance.hpp | 45 +++- ...d_convolution_forward_bias_bnorm_clamp.hpp | 57 +++++ ...rward_bias_bnorm_clamp_wmma_cshufflev3.inc | 78 +++++++ .../CMakeLists.txt | 32 +-- ...fflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 65 ++++++ ...ufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 65 ++++++ .../CMakeLists.txt | 78 +++---- ...fflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 65 ++++++ ...ufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 65 ++++++ ...grouped_conv_fwd_bias_bnorm_clamp_impl.hpp | 35 +-- profiler/src/CMakeLists.txt | 3 + ...file_grouped_conv_fwd_bias_bnorm_clamp.cpp | 202 ++++++++++++++++++ test/CMakeLists.txt | 1 + .../CMakeLists.txt | 15 +- ...st_grouped_convnd_fwd_bias_bnorm_clamp.cpp | 35 +-- ...grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp | 35 ++- 16 files changed, 768 insertions(+), 108 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp index 61b85dd12c..65ac3b7bc5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp @@ -24,9 +24,10 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddClamp = ck::tensor_operation::element_wise::AddClamp; -using Clamp = ck::tensor_operation::element_wise::Clamp; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -40,6 +41,25 @@ static constexpr auto ConvFwdOddC = static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template ; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version | + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + template && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc new file mode 100644 index 0000000000..e2ad6df07e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp_wmma_cshufflev3.inc @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index cf1eaf0e12..d089663f37 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# XDL_AND_WMMA_KERNELS set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) @@ -69,15 +69,6 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl ) -set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances - TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in - NUM_SHARDS 3 - SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP - OUTPUT_DIR ${GENERATED_DIR}/xdl -) - # large tensor # NHWGC, GKYXC, NHWGK @@ -89,7 +80,6 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances @@ -108,6 +98,15 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances + TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances @@ -193,7 +192,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances @@ -325,4 +324,11 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) -add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP}) +#WMMA_Cshuffle_v3 +add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance + wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP} +) + + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4186771720 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..6d69352e4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 2, + // NHWGC, + // GKYXC, + // Tuple, + // NHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 9796c561c0..dc759cbb54 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS -set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) +# XDL_AND_WMMA_KERNELS +set(GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) @@ -11,7 +11,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -20,7 +20,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -29,7 +29,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -38,7 +38,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -47,7 +47,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -56,7 +56,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -65,7 +65,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) @@ -74,7 +74,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instances TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16x16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) # large tensor @@ -85,16 +85,17 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -103,7 +104,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 2 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -112,7 +113,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 2 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) @@ -124,7 +125,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -133,7 +134,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -142,7 +143,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) @@ -151,7 +152,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in NUM_SHARDS 3 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) #mem @@ -162,16 +163,15 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -180,7 +180,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NDHWGC, GKZYXC, NDHWGK @@ -190,7 +190,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -199,7 +199,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -208,7 +208,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.in NUM_SHARDS 20 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -217,7 +217,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -226,7 +226,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in NUM_SHARDS 16 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) @@ -238,7 +238,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in NUM_SHARDS 11 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -247,7 +247,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -256,7 +256,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -265,7 +265,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in NUM_SHARDS 4 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -274,7 +274,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -283,7 +283,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instance.in NUM_SHARDS 1 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -292,7 +292,7 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instance.in NUM_SHARDS 5 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) @@ -301,8 +301,14 @@ generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instance.in NUM_SHARDS 12 - SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + SRC_LIST GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) -add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance ${GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP}) +#WMMA_Cshuffle_v3 +add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance + wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + ${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP} +) + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..b67e6e7c7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_bf16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..0bddf9b8f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); + + // Note: Commented out temporarily , might be used later. + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); + + // add_device_operation_instances(instances, + // device_grouped_conv_fwd_wmma_cshufflev3_f16_generic_instances< + // 3, + // NDHWGC, + // GKZYXC, + // Tuple, + // NDHWGK, + // ConvFwd1x1S1P0, + // Tuple, + // BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp index 22ff02676a..e47cc72b60 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp @@ -122,12 +122,12 @@ template -bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param, - int instance_index = -1) +bool profile_grouped_conv_fwd_bias_bnorm_clamp_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + int instance_index = -1) { const float floor = 0.f; const float ceil = 2048.f; @@ -198,18 +198,29 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, std::cout << "scale: " << scale.mDesc << std::endl; std::cout << "shift: " << shift.mDesc << std::endl; + // Note: For the integer initialization method (which is used for verification in the tests), I + // changed the initialization ranges such that the overall operation becomes monotone. This + // means that all multiplications are positive, and all additions are positive. Without this, + // the outelementop can make small relative errors arbitrarily large by shifting them toward + // zero. In this specific case this should not be an issue, since small integer inputs should + // lead to exact outputs from the gemm. However, this is not the case on RDNA3, where integer + // inputs can lead to slightly off-integer outputs. This is another issue to investigate, but it + // remains the case that the outelementop blowing up tiny errors is not reasonable, so changing + // the operation to monotone for now. If we want to move away from monotone we would need to + // have a proper error propagation analysis, which is much more complicated. switch(init_method) { case 0: break; case 1: - input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{0, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{0, 5}); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{0, 5}); + // Mean is negative because this is subtracted. + mean.GenerateTensorValue(GeneratorTensor_2{-5, 0}); variance.GenerateTensorValue(GeneratorTensor_2{0, 5}); - scale.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - shift.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + scale.GenerateTensorValue(GeneratorTensor_2{0, 5}); + shift.GenerateTensorValue(GeneratorTensor_2{0, 5}); break; default: input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 3379fd15d1..012d6e1502 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -100,6 +100,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) @@ -240,6 +241,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) diff --git a/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp new file mode 100644 index 0000000000..179317bb28 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bias_bnorm_clamp.cpp @@ -0,0 +1,202 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bias_bnorm_clamp" +#define OP_DESC "Grouped Convolution Forward+Bias+Bnorm+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bias_bnorm_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using TF32 = ck::tf32_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_bnorm_clamp); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ef2ac098ac..b0b5f1c82f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -33,6 +33,7 @@ set(REGRESSION_TESTS test_convnd_fwd test_convnd_bwd_data test_grouped_convnd_fwd + test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_bwd_weight test_softmax_rank3 diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index e87ef77e6d..4808f82101 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -1,15 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -if(GPU_TARGETS MATCHES "gfx9|gfx12") - #Fail on gfx11 CI but fail to reproduce it in local, disable it temporary - add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp) - target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) - - add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp) - target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) -endif() - if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp) target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) @@ -26,4 +17,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd_scale test_grouped_convnd_fwd_scale.cpp) target_link_libraries(test_grouped_convnd_fwd_scale PRIVATE utility device_grouped_conv3d_fwd_scale_instance) + + add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_gk_bias_bnorm_clamp test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_gk_bias_bnorm_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_bnorm_clamp_instance device_grouped_conv3d_fwd_bias_bnorm_clamp_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp index c54b218739..93007131ab 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_bnorm_clamp.cpp @@ -39,23 +39,24 @@ class TestGroupedConvndFwd : public ::testing::Test continue; } auto& param = conv_params[i]; - pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - instance_index); + pass = pass && + ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + instance_index); } EXPECT_TRUE(pass); } diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp index 8d0024354b..e17cd70d97 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_bnorm_clamp.cpp @@ -38,24 +38,23 @@ class TestGroupedConvndFwd : public ::testing::Test continue; } auto& param = conv_params[i]; - pass = pass && - ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - instance_index); + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_bnorm_clamp_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + DataType, + DataType, + DataType, + DataType, + DataType, + IndexType, + true /*ElementwiseGK*/>(true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + instance_index); } EXPECT_TRUE(pass); } From 44f481a45ca75b234ba60fdc3dc68974b1b86164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 22 Jan 2026 15:11:18 +0100 Subject: [PATCH 07/42] [CK TILE] Fix basic gemm pipelines (#3611) * [CK TILE] Fix basic pipelines * fixes --- .../20_grouped_convolution/conv_configs.hpp | 18 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 596 ++++++++++++------ .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 368 ++++++----- 3 files changed, 597 insertions(+), 385 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 620b505820..847030fffb 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -257,6 +257,24 @@ struct ConvTypeConfig template struct PipelineTypeTraits; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAGmemBGmemCRegV2; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 9b7213837a..60453d8d51 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -39,6 +39,8 @@ struct BaseGemmPipelineAGmemBGmemCRegV1 template struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -123,227 +125,411 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1(); } - template ::value && - is_detected::value, - bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const + template + struct PipelineImpl : public PipelineImplBase { - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; + }; - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr bool is_a_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_a_col_major - ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "A block window has incorrect lengths for defined ALayout!"); - static_assert(is_b_row_major - ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) - : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), - "B block window has incorrect lengths for defined BLayout!"); - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - kLdsAlignmentInBytes) * - kLdsAlignmentInBytes; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Block GEMM - auto block_gemm = BlockGemm(); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - // move to 1 - // Move each A — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - // Move each B — the enhanced function move_tile_window is executed, which takes a tuple - // as input. - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - // LDS write 0 - if constexpr(is_a_col_major) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); + // move to 1 + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write 0 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } } - // LDS write 0 - if constexpr(is_b_row_major) + index_t iCounter = num_loop - 1; + while(iCounter > 0) { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp); + // global read i + 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + block_sync_lds(); + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write i + 1 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + + iCounter--; } - else + + // tail { - store_tile(b_copy_lds_window, elementwise_Bs_res); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } + + return c_block_tile; } + }; - index_t iCounter = num_loop - 1; - while(iCounter > 0) + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const { - // global read i + 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; - block_sync_lds(); + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; - block_sync_lds(); + static_assert(is_a_col_major + ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - // LDS write i + 1 - if constexpr(is_a_col_major) { - auto a_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); - store_tile(a_copy_lds_window, a_shuffle_tmp_loop); - } - else - { - store_tile(a_copy_lds_window, elementwise_As_res); + // move to 1 + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write 0 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } } - // LDS write i + 1 - if constexpr(is_b_row_major) + index_t iCounter = num_loop - 1; + while(iCounter > 0) { - auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); - store_tile(b_copy_lds_window, b_shuffle_tmp_loop); - } - else - { - store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read i + 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + block_sync_lds(); + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); + } + else + { + store_tile(a_copy_lds_window, elementwise_As_res); + } + + // LDS write i + 1 + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); + } + else + { + store_tile(b_copy_lds_window, elementwise_Bs_res); + } + + iCounter--; } - iCounter--; + // tail + { + block_sync_lds(); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - - return c_block_tile; - } + }; template {}.operator()( a_dram_block_window_tmp, [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, @@ -379,6 +565,28 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.operator()(a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index c711c768ec..35ae2085ca 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -38,6 +38,8 @@ struct BaseGemmPipelineAGmemBGmemCRegV2 template struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2 { + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -56,6 +58,8 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2>; using BDataType = remove_cvref_t>; + using BlockGemm = remove_cvref_t())>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -127,205 +131,187 @@ struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2(); } + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == + BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / + APackedSize, + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + + // prefetch + // global read 0 + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + { + // move to 1 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read 1 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // LDS write 0 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read 1 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + store_tile(a_copy_lds_window, elementwise_As_res); + // global read i + 2 + elementwise_As_res = + load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // LDS write i + 1 + store_tile(b_copy_lds_window, elementwise_Bs_res); + // global read i + 2 + elementwise_Bs_res = + load_tile_with_elementwise(bs_copy_dram_window, b_element_func); + + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // LDS write num_loop - 1 + store_tile(a_copy_lds_window, elementwise_As_res); + + store_tile(b_copy_lds_window, elementwise_Bs_res); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; + } + }; + template ::value && is_detected::value, bool>* = nullptr> - CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BsDramBlockWindowTmp& b_dram_block_window_tmp, - const BElementFunction& b_element_func, - index_t num_loop, - void* p_smem) const - { - - using ADramBlockWindowTmp = - remove_cvref_t{}, AsDramBlockWindowTmp>>; - using BDramBlockWindowTmp = - remove_cvref_t{}, BsDramBlockWindowTmp>>; - - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil( - sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) * - 16; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); - - // A DRAM tile window for load - auto as_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); - }, - number{}); - - // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - as_copy_dram_window[number<0>{}].get_tile_distribution()); - - // B DRAM tile window for load - auto bs_copy_dram_window = generate_tuple( - [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); - }, - number{}); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - bs_copy_dram_window[number<0>{}].get_tile_distribution()); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); - - // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode()); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_lds_load_tile_distr); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_lds_load_tile_distr); - - // Acc register tile - auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - - // prefetch - // global read 0 - // Load tile — during value loading, an elementwise function is executed for each A0, - // A1, … AN. The values A0, A1, … AN are read by the same thread. - auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - // Load tile — during value loading, an elementwise function is executed for each B0, - // B1, … BN. The values B0, B1, … BN are read by the same thread. - auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - { - // move to 1 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read 1 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write 0 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read 1 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(as_copy_dram_window, {0, kKPerBlock}); - move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, elementwise_As_res); - // global read i + 2 - elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); - - // LDS write i + 1 - store_tile(b_copy_lds_window, elementwise_Bs_res); - // global read i + 2 - elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); - - iCounter--; - - } while(iCounter > 0); - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, elementwise_As_res); - - store_tile(b_copy_lds_window, elementwise_Bs_res); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } - - return c_block_tile; - } - - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { - return operator()( + return PipelineImpl{}.operator()( a_dram_block_window_tmp, [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, From 9e049a32a11267d7584c498dda11e9febfa7e9e9 Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Thu, 22 Jan 2026 09:34:33 -0800 Subject: [PATCH 08/42] Adding dispatcher architecture (#3300) * WIP POC of dispatcher * Dispatcher python workflow setup. * Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes Improvements and python to CK example Improvements to readme * Fixes to python paths * Cleaning up code * Improving dispatcher support for different arch Fixing typos * Fix formatting errors * Cleaning up examples * Improving codegeneration * Improving and fixing C++ examples * Adding conv functionality (fwd,bwd,bwdw) and examples. * Fixes based on feedback. * Further fixes based on feedback. * Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug. * Another round of improvements based on feedback. * Trimming out unnecessary code. * Fixing the multi-D implementation. * Using gpu verification for gemms and fixing convolutions tflops calculation. * Fix counter usage issue and arch filtering per ops. * Adding changelog and other fixes. * Improve examples and resolve critical bugs. * Reduce build time for python examples. * Fixing minor bug. * Fix compilation error. * Improve installation instructions for dispatcher. * Add docker based installation instructions for dispatcher. * Fixing arch-based filtering to match tile engine. * Remove dead code and fix arch filtering. * Minor bugfix. * Updates after rebase. * Trimming code. * Fix copyright headers. * Consolidate examples, cut down code. * Minor fixes. * Improving python examples. * Update readmes. * Remove conv functionality. * Cleanup following conv removable. --- .gitignore | 16 + CHANGELOG.md | 1 + dispatcher/CMakeLists.txt | 117 + dispatcher/README.md | 736 ++++++ dispatcher/bindings/README.md | 109 + dispatcher/bindings/ctypes/CMakeLists.txt | 181 ++ .../bindings/ctypes/conv_bwdw_ctypes_lib.cpp | 175 ++ .../bindings/ctypes/conv_ctypes_lib.cpp | 411 +++ .../bindings/ctypes/gemm_ctypes_lib.cpp | 401 +++ dispatcher/bindings/ctypes/gpu_helper.cpp | 206 ++ dispatcher/codegen/ADDING_NEW_GPU.md | 197 ++ dispatcher/codegen/CMakeLists.txt | 125 + dispatcher/codegen/README.md | 123 + dispatcher/codegen/arch_filter.py | 1012 +++++++ dispatcher/codegen/arch_specs.json | 270 ++ dispatcher/codegen/arch_specs_generated.py | 358 +++ dispatcher/codegen/default_config.json | 27 + dispatcher/codegen/generate_arch_specs.py | 452 ++++ .../generate_dispatcher_registration.py | 429 +++ .../codegen/generate_kernel_wrappers.py | 430 +++ dispatcher/codegen/kernel_config_loader.py | 798 ++++++ dispatcher/codegen/preselected_kernels.py | 518 ++++ dispatcher/codegen/unified_gemm_codegen.py | 1713 ++++++++++++ dispatcher/examples/CMakeLists.txt | 448 ++++ dispatcher/examples/README.md | 210 ++ .../examples/gemm/cpp/01_basic_gemm.cpp | 243 ++ .../examples/gemm/cpp/02_multi_size.cpp | 215 ++ .../gemm/cpp/03_benchmark_validation.cpp | 344 +++ .../examples/gemm/cpp/04_heuristics.cpp | 168 ++ .../examples/gemm/cpp/05_json_export.cpp | 127 + .../examples/gemm/cpp/06_multi_registry.cpp | 294 +++ dispatcher/examples/gemm/cpp/README.md | 229 ++ .../examples/gemm/python/01_basic_gemm.py | 331 +++ .../examples/gemm/python/02_batch_gemm.py | 149 ++ .../examples/gemm/python/03_benchmark.py | 171 ++ .../examples/gemm/python/04_validation.py | 156 ++ .../gemm/python/05_numpy_integration.py | 166 ++ .../examples/gemm/python/06_json_export.py | 169 ++ .../examples/gemm/python/07_stress_test.py | 513 ++++ .../examples/gemm/python/08_heuristics.py | 718 +++++ .../examples/gemm/python/09_multi_registry.py | 220 ++ .../gemm/python/10_advanced_benchmark.py | 260 ++ .../examples/gemm/python/11_json_import.py | 310 +++ dispatcher/examples/gemm/python/README.md | 299 +++ dispatcher/examples/gemm/python/kernels.json | 80 + dispatcher/include/ck_tile/dispatcher.hpp | 19 + .../include/ck_tile/dispatcher/README.md | 161 ++ .../ck_tile/dispatcher/arch_filter.hpp | 393 +++ .../dispatcher/arch_specs_generated.hpp | 168 ++ .../backends/generated_kernel_backend.hpp | 143 + .../backends/generated_tile_backend.hpp | 157 ++ .../backends/kernel_registration.hpp | 109 + .../dispatcher/backends/tile_backend.hpp | 173 ++ .../include/ck_tile/dispatcher/dispatcher.hpp | 146 + .../ck_tile/dispatcher/example_args.hpp | 230 ++ .../ck_tile/dispatcher/json_export.hpp | 370 +++ .../ck_tile/dispatcher/kernel_config.hpp | 370 +++ .../ck_tile/dispatcher/kernel_decl.hpp | 509 ++++ .../ck_tile/dispatcher/kernel_instance.hpp | 68 + .../include/ck_tile/dispatcher/kernel_key.hpp | 428 +++ .../include/ck_tile/dispatcher/problem.hpp | 311 +++ .../include/ck_tile/dispatcher/registry.hpp | 197 ++ .../include/ck_tile/dispatcher/utils.hpp | 724 +++++ .../validation/reference_kernels.hpp | 228 ++ dispatcher/python/CMakeLists.txt | 9 + dispatcher/python/README.md | 60 + dispatcher/python/ctypes_utils.py | 2347 +++++++++++++++++ dispatcher/python/pytest.ini | 43 + dispatcher/python/requirements.txt | 22 + dispatcher/scripts/compile_gemm_examples.py | 2253 ++++++++++++++++ dispatcher/scripts/example_kernel_builder.py | 1447 ++++++++++ dispatcher/scripts/parallel_kernel_builder.py | 142 + dispatcher/scripts/stress_test_autocorrect.py | 540 ++++ dispatcher/src/dispatcher.cpp | 152 ++ dispatcher/src/registry.cpp | 288 ++ dispatcher/tests/CMakeLists.txt | 343 +++ dispatcher/tests/test_autocorrect.py | 625 +++++ dispatcher/tests/test_dispatcher.cpp | 296 +++ dispatcher/tests/test_dispatcher_extended.cpp | 499 ++++ dispatcher/tests/test_examples_integration.py | 337 +++ dispatcher/tests/test_json_export.cpp | 448 ++++ dispatcher/tests/test_kernel_key.cpp | 147 ++ dispatcher/tests/test_kernel_key_extended.cpp | 453 ++++ dispatcher/tests/test_minimal.cpp | 57 + dispatcher/tests/test_mock_kernel.cpp | 6 + dispatcher/tests/test_mock_kernel.hpp | 134 + dispatcher/tests/test_problem.cpp | 96 + dispatcher/tests/test_problem_extended.cpp | 457 ++++ .../tests/test_real_kernel_correctness.cpp | 232 ++ .../tests/test_real_kernel_multi_size.cpp | 213 ++ .../tests/test_real_kernel_performance.cpp | 173 ++ dispatcher/tests/test_real_kernel_simple.cpp | 201 ++ dispatcher/tests/test_registry.cpp | 166 ++ dispatcher/tests/test_registry_extended.cpp | 503 ++++ dispatcher/tests/test_regression.cpp | 492 ++++ dispatcher/tests/test_sanity_ck_tile.cpp | 607 +++++ dispatcher/tests/test_tile_backend.cpp | 155 ++ 97 files changed, 33472 insertions(+) create mode 100644 dispatcher/CMakeLists.txt create mode 100644 dispatcher/README.md create mode 100644 dispatcher/bindings/README.md create mode 100644 dispatcher/bindings/ctypes/CMakeLists.txt create mode 100644 dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp create mode 100644 dispatcher/bindings/ctypes/conv_ctypes_lib.cpp create mode 100644 dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp create mode 100644 dispatcher/bindings/ctypes/gpu_helper.cpp create mode 100644 dispatcher/codegen/ADDING_NEW_GPU.md create mode 100644 dispatcher/codegen/CMakeLists.txt create mode 100644 dispatcher/codegen/README.md create mode 100644 dispatcher/codegen/arch_filter.py create mode 100644 dispatcher/codegen/arch_specs.json create mode 100644 dispatcher/codegen/arch_specs_generated.py create mode 100644 dispatcher/codegen/default_config.json create mode 100644 dispatcher/codegen/generate_arch_specs.py create mode 100644 dispatcher/codegen/generate_dispatcher_registration.py create mode 100644 dispatcher/codegen/generate_kernel_wrappers.py create mode 100644 dispatcher/codegen/kernel_config_loader.py create mode 100644 dispatcher/codegen/preselected_kernels.py create mode 100755 dispatcher/codegen/unified_gemm_codegen.py create mode 100644 dispatcher/examples/CMakeLists.txt create mode 100644 dispatcher/examples/README.md create mode 100644 dispatcher/examples/gemm/cpp/01_basic_gemm.cpp create mode 100644 dispatcher/examples/gemm/cpp/02_multi_size.cpp create mode 100644 dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp create mode 100644 dispatcher/examples/gemm/cpp/04_heuristics.cpp create mode 100644 dispatcher/examples/gemm/cpp/05_json_export.cpp create mode 100644 dispatcher/examples/gemm/cpp/06_multi_registry.cpp create mode 100644 dispatcher/examples/gemm/cpp/README.md create mode 100644 dispatcher/examples/gemm/python/01_basic_gemm.py create mode 100644 dispatcher/examples/gemm/python/02_batch_gemm.py create mode 100644 dispatcher/examples/gemm/python/03_benchmark.py create mode 100644 dispatcher/examples/gemm/python/04_validation.py create mode 100644 dispatcher/examples/gemm/python/05_numpy_integration.py create mode 100644 dispatcher/examples/gemm/python/06_json_export.py create mode 100644 dispatcher/examples/gemm/python/07_stress_test.py create mode 100644 dispatcher/examples/gemm/python/08_heuristics.py create mode 100644 dispatcher/examples/gemm/python/09_multi_registry.py create mode 100644 dispatcher/examples/gemm/python/10_advanced_benchmark.py create mode 100644 dispatcher/examples/gemm/python/11_json_import.py create mode 100644 dispatcher/examples/gemm/python/README.md create mode 100644 dispatcher/examples/gemm/python/kernels.json create mode 100644 dispatcher/include/ck_tile/dispatcher.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/README.md create mode 100644 dispatcher/include/ck_tile/dispatcher/arch_filter.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/dispatcher.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/example_args.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/json_export.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/kernel_config.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/kernel_key.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/problem.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/registry.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/utils.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp create mode 100644 dispatcher/python/CMakeLists.txt create mode 100644 dispatcher/python/README.md create mode 100644 dispatcher/python/ctypes_utils.py create mode 100644 dispatcher/python/pytest.ini create mode 100644 dispatcher/python/requirements.txt create mode 100644 dispatcher/scripts/compile_gemm_examples.py create mode 100755 dispatcher/scripts/example_kernel_builder.py create mode 100755 dispatcher/scripts/parallel_kernel_builder.py create mode 100644 dispatcher/scripts/stress_test_autocorrect.py create mode 100644 dispatcher/src/dispatcher.cpp create mode 100644 dispatcher/src/registry.cpp create mode 100644 dispatcher/tests/CMakeLists.txt create mode 100644 dispatcher/tests/test_autocorrect.py create mode 100644 dispatcher/tests/test_dispatcher.cpp create mode 100644 dispatcher/tests/test_dispatcher_extended.cpp create mode 100644 dispatcher/tests/test_examples_integration.py create mode 100644 dispatcher/tests/test_json_export.cpp create mode 100644 dispatcher/tests/test_kernel_key.cpp create mode 100644 dispatcher/tests/test_kernel_key_extended.cpp create mode 100644 dispatcher/tests/test_minimal.cpp create mode 100644 dispatcher/tests/test_mock_kernel.cpp create mode 100644 dispatcher/tests/test_mock_kernel.hpp create mode 100644 dispatcher/tests/test_problem.cpp create mode 100644 dispatcher/tests/test_problem_extended.cpp create mode 100644 dispatcher/tests/test_real_kernel_correctness.cpp create mode 100644 dispatcher/tests/test_real_kernel_multi_size.cpp create mode 100644 dispatcher/tests/test_real_kernel_performance.cpp create mode 100644 dispatcher/tests/test_real_kernel_simple.cpp create mode 100644 dispatcher/tests/test_registry.cpp create mode 100644 dispatcher/tests/test_registry_extended.cpp create mode 100644 dispatcher/tests/test_regression.cpp create mode 100644 dispatcher/tests/test_sanity_ck_tile.cpp create mode 100644 dispatcher/tests/test_tile_backend.cpp diff --git a/.gitignore b/.gitignore index 740d5464fb..a2fb1473ab 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,23 @@ CMakeUserPresets.json # Python cache __pycache__/ +# Cache directories .cache/ +.ck_tile_cache/ +ck_tile_cache/ +**/kernel_cache/ +**/.kernel_cache/ + +# Dispatcher kernel cache (user-generated, can be large) +dispatcher/**/kernel_cache/ +dispatcher/**/.kernel_cache/ +dispatcher/**/cached_kernels/ +dispatcher/**/*.hsaco +dispatcher/**/*.co + +# Dispatcher generated JSON exports +dispatcher/**/*_kernels.json +dispatcher/**/dispatcher_kernels.json # Generated test data test_data/* diff --git a/CHANGELOG.md b/CHANGELOG.md index dfb50e9bdd..5f17a4d768 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM. diff --git a/dispatcher/CMakeLists.txt b/dispatcher/CMakeLists.txt new file mode 100644 index 0000000000..2acc73d1d5 --- /dev/null +++ b/dispatcher/CMakeLists.txt @@ -0,0 +1,117 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +project(ck_tile_dispatcher VERSION 1.0.0 LANGUAGES CXX) + +# C++17 required +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Find HIP for headers (needed for validation kernels) +find_package(hip QUIET) +if(NOT hip_FOUND) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip) + find_package(hip REQUIRED) +endif() + +# Dispatcher library +add_library(ck_tile_dispatcher + src/registry.cpp + src/dispatcher.cpp +) + +# Enable PIC for Python bindings +set_target_properties(ck_tile_dispatcher PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against CK Tile headers (header-only) +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against HIP headers if available +if(hip_FOUND) + target_link_libraries(ck_tile_dispatcher PUBLIC hip::host) +endif() + +# Compiler warnings +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(ck_tile_dispatcher PRIVATE + -Wall -Wextra -Wpedantic + ) +elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(ck_tile_dispatcher PRIVATE + /W4 + ) +endif() + +# Optional: Build tests +option(BUILD_DISPATCHER_TESTS "Build dispatcher unit tests" OFF) +if(BUILD_DISPATCHER_TESTS) + enable_testing() + add_subdirectory(tests) +endif() + +# Optional: Build Python bindings +option(BUILD_DISPATCHER_PYTHON "Build Python bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_PYTHON) + add_subdirectory(python) +endif() + +# Optional: Codegen for tile_engine integration +option(DISPATCHER_AUTO_GENERATE_WRAPPERS "Auto-generate wrappers from tile_engine" OFF) +if(DISPATCHER_AUTO_GENERATE_WRAPPERS) + add_subdirectory(codegen) +endif() + +# Optional: Build examples +option(BUILD_DISPATCHER_EXAMPLES "Build dispatcher examples" OFF) +if(BUILD_DISPATCHER_EXAMPLES) + add_subdirectory(examples) +endif() + +# Optional: Build ctypes bindings +option(BUILD_DISPATCHER_BINDINGS "Build language bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_BINDINGS) + add_subdirectory(bindings/ctypes) +endif() + +# If codegen is enabled, add generated include directory +if(DISPATCHER_AUTO_GENERATE_WRAPPERS AND DISPATCHER_GENERATED_INCLUDE_DIR) + target_include_directories(ck_tile_dispatcher + PUBLIC + $ + ) +endif() + +# Installation +install(TARGETS ck_tile_dispatcher + EXPORT ck_tile_dispatcher_targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin +) + +install(DIRECTORY include/ + DESTINATION include + FILES_MATCHING PATTERN "*.hpp" +) + +install(EXPORT ck_tile_dispatcher_targets + FILE ck_tile_dispatcher_targets.cmake + NAMESPACE ck_tile:: + DESTINATION lib/cmake/ck_tile_dispatcher +) + diff --git a/dispatcher/README.md b/dispatcher/README.md new file mode 100644 index 0000000000..fa3fbd3a59 --- /dev/null +++ b/dispatcher/README.md @@ -0,0 +1,736 @@ +# CK Tile Dispatcher + +A unified kernel dispatch system for AMD GPUs with C++ and Python frontends. + +**Validated Platform:** AMD Instinct MI300 series (gfx942) + + +--- + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Docker Setup](#docker-setup-recommended) +3. [Prerequisites](#prerequisites) +4. [Step-by-Step Build Guide](#step-by-step-build-guide) +5. [Running Examples](#running-examples) +6. [External Integration](#external-integration) +7. [Core Concepts](#core-concepts) +8. [Troubleshooting](#troubleshooting) +9. [File Structure](#file-structure) + +--- + +## Quick Start + +**Complete setup from scratch (5 minutes):** + +```bash +# From the composable_kernel root directory +cd dispatcher + +# Step 1: Create build directory +mkdir -p build && cd build + +# Step 2: Configure CMake +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Step 3: Generate kernels and build (CMake handles this automatically) +make -j$(nproc) + +# Step 4: Run C++ examples +./examples/gemm_01_basic + +# Step 5: Build Python libraries (required for Python examples) +make python_libs + +# Step 6: Run Python examples (from dispatcher directory) +cd .. +python3 examples/gemm/python/01_basic_gemm.py +``` + +--- + +## Docker Setup (Recommended) + +For a reproducible build environment, use the official ROCm Docker image: + +### Step 1: Pull and Run Container + +```bash +# Pull the CK Docker image +docker pull rocm/composable_kernel:ck_ub24.04_rocm7.0.1 + +# Run container with GPU access +docker run \ + -it \ + --privileged \ + --device=/dev/kfd \ + --device=/dev/dri \ + --group-add video \ + --group-add render \ + -w /root/workspace \ + -v $(pwd):/root/workspace \ + rocm/composable_kernel:ck_ub24.04_rocm7.0.1 \ + /bin/bash +``` + +> **Note:** Omit `--device` flags if building without GPU access. + +### Step 2: Clone and Build + +```bash +# Inside the container +git clone https://github.com/ROCm/composable_kernel.git +cd composable_kernel +git checkout builder-dispatch-tile-gemm + +# Set up Python environment +python3 -m venv .venv +source .venv/bin/activate +pip install numpy + +# Build dispatcher +cd dispatcher +mkdir -p build && cd build +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +make -j$(nproc) +``` + +### One-Liner Build (inside container) + +```bash +git clone https://github.com/ROCm/composable_kernel.git && \ +cd composable_kernel && git checkout builder-dispatch-tile-gemm && \ +python3 -m venv .venv && source .venv/bin/activate && pip install numpy && \ +cd dispatcher && mkdir -p build && cd build && \ +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS="gfx942" -DBUILD_DISPATCHER_EXAMPLES=ON && \ +make -j$(nproc) +``` + +--- + +## Prerequisites + +### Required Software + +| Software | Minimum Version | Check Command | +|----------|-----------------|---------------| +| ROCm | 6.4+ | `rocminfo` | +| CMake | 3.16+ | `cmake --version` | +| Python | 3.8+ | `python3 --version` | +| NumPy | 1.20+ | `pip show numpy` | +| hipcc | (from ROCm) | `/opt/rocm/bin/hipcc --version` | + +> **Note:** Newer GPU targets (gfx950, gfx1201) require ROCm 6.3+. For ROCm 6.4+, you can also use `amdclang++` instead of `hipcc`. + +### Check Your GPU Architecture + +```bash +# Find your GPU architecture +rocminfo | grep -i "gfx" +# Example output: "gfx942" +``` + +**Supported architectures:** +- **gfx942** - MI300X, MI300A, MI308, MI325 (Instinct MI300 series) +- **gfx90a** - MI200 series (MI250, MI250X) +- **gfx950** - MI350 series +- **gfx1101** - RDNA3 series +- **gfx1201** - RDNA4 series + +### Install Python Dependencies + +NumPy is required for Python examples and kernel generation. We recommend using a virtual environment: + +**Option 1: Using standard venv** +```bash +# Create virtual environment +python3 -m venv .venv + +# Activate virtual environment +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +pip install numpy +``` + +**Option 2: Using uv (faster alternative)** +```bash +# Install uv if not already installed +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Create and activate virtual environment +uv venv .venv +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +uv pip install numpy +``` + +**Option 3: System-wide install (not recommended)** +```bash +pip install numpy +``` + +> **Note:** Always activate your virtual environment before running CMake or Python examples. + +### Supported Data Types + +CK Tile supports a wide range of data types for GEMM operations: + +| A dtype | B dtype | Acc dtype | Warp Tile Sizes | Notes | +|---------|---------|-----------|-----------------|-------| +| `fp32` | `fp32` | `fp32` | 16x16x4, 16x16x16 | Full precision | +| `fp16` | `fp16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Standard half | +| `bf16` | `bf16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Brain float 16 | +| `fp8` | `fp8` | `fp32` | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | FP8 E4M3 | +| `fp8` | `bf8` | `fp32` | 32x32x16, 16x16x32 | Mixed FP8/BF8 | +| `bf8` | `fp8` | `fp32` | 32x32x16, 16x16x128 | Mixed BF8/FP8 | +| `bf8` | `bf8` | `fp32` | 32x32x16, 32x32x32, 16x16x32 | BF8 E5M2 | +| `int8` | `int8` | `int32` | 32x32x16, 16x16x32, 16x16x16 | Integer GEMM | +| `pk_fp4` | `pk_fp4` | `fp32` | 16x16x128 | Packed 4-bit float | + +**Notes:** +- Accumulator is always `fp32` except for `int8` which uses `int32` +- FP8 types: `fp8` = E4M3, `bf8` = E5M2 +- `pk_fp4` = Packed 4-bit float (2 values per byte) +- Some dtypes require specific GPU architectures (e.g., FP8 requires MI300+) + +--- + +## Step-by-Step Build Guide + +### Step 1: Navigate to Dispatcher Directory + +```bash +# From composable_kernel root +cd dispatcher + +# Verify you're in the right place +ls CMakeLists.txt # Should exist +``` + +### Step 2: Create Build Directory + +```bash +mkdir -p build +cd build +``` + +### Step 3: Configure CMake + +**Basic configuration (library only):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" +``` + +**Full configuration (with examples and tests):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON \ + -DBUILD_DISPATCHER_TESTS=ON +``` + +**Expected output:** +``` +-- Found hip: /opt/rocm (found suitable version "6.x.x") +-- Generating GEMM kernels... +-- Built: gemm_01 through gemm_06, dispatcher_gemm_lib.so +-- Configuring done +``` + +### Step 4: Build + +```bash +# Build all targets (generates kernels automatically, then compiles) +make -j$(nproc) + +# Or build specific targets +make gemm_01_basic # Single GEMM example +make dispatcher_gemm_lib # GEMM shared library for Python + +# Build ONLY Python libraries (faster if you don't need C++ examples) +make python_libs -j$(nproc) +``` + +### Kernel Generation Targets + +Kernels are generated automatically during `make`, but you can also control generation explicitly: + +```bash +# Generate all kernels only (no compilation) +make generate_all_kernels + +# Generate GEMM kernels only +make generate_gemm_kernels + +# Force regenerate (even if kernels exist) +make regenerate_all_kernels +make regenerate_gemm_kernels + +# Generate for specific GPU architecture +make generate_kernels_gfx942 # MI300X +make generate_kernels_gfx90a # MI200 +make generate_kernels_gfx1100 # RDNA3 +``` + +### Step 5: Verify Build + +```bash +# Check executables were built +ls examples/gemm_* + +# Check shared libraries were built +ls examples/libdispatcher_gemm_lib.so +``` + +### CMake Options Reference + +| Flag | Default | Description | +|------|---------|-------------| +| `CMAKE_BUILD_TYPE` | Debug | **Use `Release` for performance!** | +| `GPU_TARGETS` | None | Target GPU: `"gfx942"`, `"gfx90a"`, etc. | +| `BUILD_DISPATCHER_EXAMPLES` | OFF | Build C++ examples and Python libs | +| `BUILD_DISPATCHER_TESTS` | OFF | Build unit tests | +| `CMAKE_PREFIX_PATH` | - | ROCm installation path | +| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler | + +⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. +⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). + +--- + +## Running Examples + +### C++ Examples + +After building, executables are in `build/examples/`: + +```bash +cd build/examples + +# GEMM Examples +./gemm_01_basic # Basic GEMM with autofill/autocorrect +./gemm_02_multi_size # Wildcard expansion +./gemm_03_benchmark_validation # Benchmarking + validation +./gemm_04_heuristics # Heuristic kernel selection +./gemm_05_json_export # Registry JSON export +./gemm_06_multi_registry # Multiple registries +``` + +### Python Examples + +Run from the `dispatcher` directory: + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM Examples +python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM +python3 examples/gemm/python/04_validation.py # CPU reference validation +python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels) +python3 examples/gemm/python/08_heuristics.py # Heuristic selection +``` + +### Example Output + +**Expected C++ output (`gemm_01_basic`):** +``` +====================================================================== +Example 01: Basic GEMM with Declarative Kernel Definition +====================================================================== + +Step 1: Declared Kernels +------------------------ +Kernel Set: fp16_gemm_kernels + Architecture: gfx942 + Configurations: 1 + - gemm_fp16_rcr_compv4_cshuffle_intrawave_128x128x32 + +Step 2: Create Registry and Dispatcher +-------------------------------------- + Registered 1 kernels + +Step 3: Define Problem +---------------------- + M=1024, N=1024, K=1024 + +Step 4: GPU Execution +--------------------- + *** GPU EXECUTION *** + Time: ms + TFLOPS: +``` + +> **Note:** Timing values vary by GPU model and system configuration. + +--- + +## Benchmark Parameters + +The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`: + +### Available Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `warmup` | int | 5 | Warmup iterations (discarded from timing) | +| `repeat` | int | 20 | Benchmark iterations (averaged) | +| `flush_cache` | bool | false | Flush GPU L2 cache between iterations | +| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) | +| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" | +| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" | +| `split_k` | int | 1 | Split-K parallelism factor | + +### Python Usage + +```python +from ctypes_utils import DispatcherLib + +# Basic usage (default benchmark settings) +lib = DispatcherLib.load() + +# Advanced benchmark settings via command line +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache +``` + +### C++ Usage + +```cpp +// Basic timing +ck_tile::stream_config cfg{nullptr, true}; + +// Advanced benchmark settings +ck_tile::stream_config cfg{ + nullptr, // stream_id (nullptr = default stream) + true, // time_kernel + 1, // log_level + 10, // cold_niters (warmup) + 100, // nrepeat + true, // is_gpu_timer + true, // flush_cache + 4 // rotating_count +}; + +float avg_time = kernel.run(args, cfg); +``` + +### Command Line (Python Examples) + +```bash +# Basic run +python3 examples/gemm/python/10_advanced_benchmark.py + +# With benchmark parameters +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache \ + --rotating-count 4 \ + --timer gpu +``` + +### When to Use Each Parameter + +| Use Case | Recommended Settings | +|----------|---------------------| +| Quick test | `warmup=1, repeat=3` | +| Stable benchmark | `warmup=10, repeat=100` | +| Memory-bound analysis | `flush_cache=True, rotating_count=4` | +| Compute-bound analysis | `flush_cache=False` (default) | +| Debug timing | `timer="cpu"` | +| Production | `timer="gpu"` (default) | + +--- + +## External Integration + +### Using Dispatcher in Your Own Project + +#### Option 1: CMake Integration (Recommended) + +Add to your `CMakeLists.txt`: + +```cmake +# Set path to composable_kernel +set(CK_ROOT "/path/to/composable_kernel") + +# Add dispatcher subdirectory +add_subdirectory(${CK_ROOT}/dispatcher dispatcher_build) + +# Link to your target +target_link_libraries(your_target PRIVATE ck_tile_dispatcher) +target_include_directories(your_target PRIVATE + ${CK_ROOT}/dispatcher/include + ${CK_ROOT}/include +) +``` + +#### Option 2: Include as Pre-built Library + +```cmake +# Find the pre-built library +find_library(CK_DISPATCHER ck_tile_dispatcher + PATHS /path/to/composable_kernel/dispatcher/build) + +# Include directories +set(CK_INCLUDE_DIRS + /path/to/composable_kernel/include + /path/to/composable_kernel/dispatcher/include +) + +target_link_libraries(your_target PRIVATE ${CK_DISPATCHER}) +target_include_directories(your_target PRIVATE ${CK_INCLUDE_DIRS}) +``` + +#### Option 3: Python Integration + +```python +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") + +# For GEMM +from ctypes_utils import DispatcherLib, Dispatcher, KernelConfig +``` + +### Required Include Paths + +When integrating, you need these include paths: + +``` +/path/to/composable_kernel/include # CK Tile core headers +/path/to/composable_kernel/dispatcher/include # Dispatcher headers +/path/to/composable_kernel/dispatcher/build/generated_kernels # Generated kernels +``` + +### Required Compile Flags + +```bash +# Minimum flags for hipcc +-std=c++17 +-D__HIP_PLATFORM_AMD__=1 +--offload-arch=gfx942 # Your target GPU + +# Recommended flags +-O3 +-mllvm -enable-noalias-to-md-conversion=0 +-Wno-undefined-func-template +-Wno-float-equal +-Wall +-Werror +``` + +### Python Path Setup + +For Python scripts outside the dispatcher directory: + +```bash +# Option 1: Environment variable +export PYTHONPATH="/path/to/composable_kernel/dispatcher/examples/gemm/python:$PYTHONPATH" + +# Option 2: In your Python script +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") +``` + +### Library Search Paths + +The Python utilities search for the shared library in these locations: + +```python +# For GEMM (ctypes_utils.py) +SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "../build/examples/libdispatcher_gemm_lib.so", + "../../build/examples/libdispatcher_gemm_lib.so", +] +``` + +If using from a different location, set the library path explicitly: + +```python +# GEMM +from ctypes_utils import DispatcherLib +lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so") +``` + +--- + +## Core Concepts + +### Data Flow + +``` +KernelConfig → Registry → Dispatcher → GPU Execution +``` + +1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) +2. **Registry**: Stores multiple kernel configurations +3. **Dispatcher**: Selects best kernel for a given problem and executes it + +### GEMM Layouts + +| Layout | A | B | C | Use Case | +|--------|---|---|---|----------| +| RCR | Row | Col | Row | Most common (PyTorch default) | +| RRR | Row | Row | Row | Both inputs row-major | +| CRR | Col | Row | Row | A transposed | +| CCR | Col | Col | Row | Both inputs column-major | + +### Split-K Support + +Split-K divides the K dimension across multiple thread blocks, useful for large K dimensions. + +**Usage (C++):** +```cpp +// GEMM with 4-way K split +auto problem = ProblemBuilder() + .m(1024).n(1024).k(8192) + .split_k(4) + .build(); +``` + +--- + +## Troubleshooting + +### Build Issues + +| Problem | Solution | +|---------|----------| +| `hipcc not found` | Set `-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc` | +| `hip not found` | Set `-DCMAKE_PREFIX_PATH=/opt/rocm` | +| Very slow performance | Use `-DCMAKE_BUILD_TYPE=Release` | +| `gfx942 not supported` | Check ROCm version (need 6.0+) | +| Kernel generation fails | Ensure Python 3.8+ with NumPy installed in active venv | +| Build errors | First verify CK builds without dispatcher (see main CK README) | + +### Runtime Issues + +| Problem | Solution | +|---------|----------| +| `Library not found` | Build with `-DBUILD_DISPATCHER_EXAMPLES=ON` | +| `No kernel found` | Check GPU arch matches build target | +| Python `ModuleNotFoundError` | Add paths to `PYTHONPATH` (see above) | +| Wrong results | Verify layout matches your data | + +### Debug Commands + +```bash +# Check ROCm installation +rocminfo | head -20 + +# Check GPU architecture +rocminfo | grep "Name:" + +# Verify library exists +ls -la build/examples/libdispatcher_*.so + +# Run with verbose output +./build/examples/gemm_01_basic 2>&1 + +# Python: Check library loading +python3 -c " +import ctypes +lib = ctypes.CDLL('/path/to/libdispatcher_gemm_lib.so') +print('Library loaded successfully') +" +``` + +### Clean Rebuild + +If you encounter issues, try a clean rebuild: + +```bash +cd dispatcher +rm -rf build +mkdir build && cd build +cmake .. [your options] +make -j$(nproc) +``` + +--- + +## File Structure + +``` +dispatcher/ +├── README.md # This file +├── CMakeLists.txt # Build configuration +│ +├── include/ck_tile/dispatcher/ # C++ headers +│ ├── dispatcher.hpp # GEMM dispatcher +│ ├── registry.hpp # Kernel registry +│ └── kernel_key.hpp # Kernel configuration +│ +├── src/ # C++ implementation +│ +├── codegen/ # Kernel generation +│ ├── unified_gemm_codegen.py # GEMM kernel generator +│ └── arch_specs.json # GPU specifications +│ +├── bindings/ctypes/ # Python ctypes interface +│ └── gemm_ctypes_lib.cpp # GEMM Python library +│ +├── examples/ # Examples +│ └── gemm/ +│ ├── cpp/ # C++ GEMM examples (01-06) +│ └── python/ # Python GEMM examples (01-11) +│ +├── scripts/ # Build scripts +│ +└── tests/ # Unit tests +``` + +--- + +## Example Documentation + +| Directory | README | +|-----------|--------| +| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | +| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | +| Codegen | [codegen/README.md](codegen/README.md) | + +--- + +## Archived Content + +Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples +- `codegen/unified_conv_codegen.py` - Conv kernel generator +- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers +- `python/conv_utils.py` - Conv Python utilities + +--- + +## License + +MIT License - Copyright (c) 2025, Advanced Micro Devices, Inc. diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md new file mode 100644 index 0000000000..7cda21f6ec --- /dev/null +++ b/dispatcher/bindings/README.md @@ -0,0 +1,109 @@ +# CK Tile Dispatcher - Language Bindings + +This directory contains language bindings for the CK Tile Dispatcher. + +## Structure + +``` +bindings/ +├── ctypes/ # Python ctypes bindings (C API) +│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API +│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) +│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API +│ ├── gpu_helper.cpp # CLI helper for Python +│ └── CMakeLists.txt +└── README.md +``` + +## ctypes Bindings + +The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`. + +### Building + +```bash +cd build +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm +make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper +``` + +### Usage from Python + +```python +import ctypes + +# Load the library +lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so") + +# Initialize +lib.dispatcher_init() + +# Check if problem is supported +is_supported = lib.dispatcher_is_supported(M, N, K) + +# Run GEMM +time_ms = ctypes.c_float() +result = lib.dispatcher_run_gemm( + A_ptr, B_ptr, C_ptr, + M, N, K, + ctypes.byref(time_ms) +) + +# Cleanup +lib.dispatcher_cleanup() +``` + +### GEMM API + +| Function | Description | +|----------|-------------| +| `dispatcher_init()` | Initialize the dispatcher | +| `dispatcher_is_supported(M, N, K)` | Check if problem size is supported | +| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get kernel name for problem | +| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM | +| `dispatcher_get_kernel_count()` | Get number of registered kernels | +| `dispatcher_export_registry_json()` | Export registry as JSON | +| `dispatcher_cleanup()` | Release resources | + +### Convolution API + +| Function | Description | +|----------|-------------| +| `conv_dispatcher_init()` | Initialize the dispatcher | +| `conv_dispatcher_is_supported(prob)` | Check if problem is supported | +| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get kernel name | +| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution | +| `conv_dispatcher_get_kernel_count()` | Get number of registered kernels | +| `conv_dispatcher_cleanup()` | Release resources | + +## GPU Helper + +The `gpu_helper` executable provides a CLI interface for Python: + +```bash +./gpu_helper 1024 1024 1024 --validate +``` + +Output is JSON for easy parsing: +```json +{ + "problem": {"M": 1024, "N": 1024, "K": 1024}, + "kernel": "gemm_fp16_rcr_...", + "execution": { + "time_ms": 0.5, + "tflops": 4.2 + }, + "validation": { + "accuracy": 100.0 + }, + "status": "success" +} +``` + +## Examples + +See the examples that use these bindings: + +- **GEMM**: `dispatcher/examples/gemm/python/` +- **Conv**: `dispatcher/examples/conv/python/` + diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt new file mode 100644 index 0000000000..804e5e9bd7 --- /dev/null +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -0,0 +1,181 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher - ctypes Bindings +# ============================================================================= +# +# Provides shared libraries with C API for Python ctypes integration. +# +# Targets: +# - dispatcher_gemm_lib : GEMM dispatcher library +# - dispatcher_conv_lib : Convolution dispatcher library (forward + bwd_data) +# - dispatcher_conv_bwdw_lib : Convolution backward weight library +# - gpu_helper : GPU helper executable for Python +# + +cmake_minimum_required(VERSION 3.16) + +# Helper function to add a ctypes library +function(add_ctypes_library TARGET_NAME SOURCE_FILE) + cmake_parse_arguments(ARG "CONV" "KERNEL_HEADER" "" ${ARGN}) + + add_library(${TARGET_NAME} SHARED ${SOURCE_FILE}) + + target_include_directories(${TARGET_NAME} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(${TARGET_NAME} PRIVATE + hip::device + ) + + # Force-include kernel header if provided + if(ARG_KERNEL_HEADER AND EXISTS ${ARG_KERNEL_HEADER}) + target_compile_options(${TARGET_NAME} PRIVATE + -include ${ARG_KERNEL_HEADER} + ) + if(ARG_CONV) + target_compile_definitions(${TARGET_NAME} PRIVATE CONV_KERNEL_AVAILABLE) + endif() + endif() + + set_target_properties(${TARGET_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endfunction() + +# ============================================================================= +# GEMM ctypes Library +# ============================================================================= + +# Find a generated GEMM kernel header for the library +file(GLOB GEMM_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/gemm_*.hpp") +if(GEMM_KERNEL_HEADERS) + list(GET GEMM_KERNEL_HEADERS 0 GEMM_KERNEL_HEADER) + message(STATUS "Found GEMM kernel for ctypes lib: ${GEMM_KERNEL_HEADER}") + + add_ctypes_library(dispatcher_gemm_lib + gemm_ctypes_lib.cpp + KERNEL_HEADER ${GEMM_KERNEL_HEADER} + ) +else() + message(STATUS "No GEMM kernel found for ctypes lib - building without kernel") + add_library(dispatcher_gemm_lib SHARED gemm_ctypes_lib.cpp) + target_include_directories(dispatcher_gemm_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device) +endif() + +# ============================================================================= +# Convolution ctypes Library (supports forward + bwd_data) +# ============================================================================= + +# Look for forward kernels +file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") +# Look for backward data kernels +file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +# Fallback: any conv kernel (for backwards compatibility) +file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") + +add_library(dispatcher_conv_lib SHARED conv_ctypes_lib.cpp) +target_include_directories(dispatcher_conv_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include +) +target_link_libraries(dispatcher_conv_lib PRIVATE hip::device) +set_target_properties(dispatcher_conv_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 +) + +# Add forward kernel if available +if(CONV_FWD_KERNEL_HEADERS) + list(GET CONV_FWD_KERNEL_HEADERS 0 CONV_FWD_KERNEL_HEADER) + message(STATUS "Found Conv FWD kernel for ctypes lib: ${CONV_FWD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_FWD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +elseif(CONV_KERNEL_HEADERS) + # Fallback to any conv kernel + list(GET CONV_KERNEL_HEADERS 0 CONV_KERNEL_HEADER) + message(STATUS "Found Conv kernel for ctypes lib: ${CONV_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +else() + message(STATUS "No Conv FWD kernel found for ctypes lib - building without kernel") +endif() + +# Add backward data kernel if available +if(CONV_BWDD_KERNEL_HEADERS) + list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) +endif() + +# ============================================================================= +# Convolution Backward Weight ctypes Library (separate lib for bwd_weight) +# ============================================================================= + +file(GLOB CONV_BWDW_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*bwd_weight*.hpp") +if(CONV_BWDW_KERNEL_HEADERS) + list(GET CONV_BWDW_KERNEL_HEADERS 0 CONV_BWDW_KERNEL_HEADER) + message(STATUS "Found Conv BwdWeight kernel for ctypes lib: ${CONV_BWDW_KERNEL_HEADER}") + + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + target_compile_options(dispatcher_conv_bwdw_lib PRIVATE + -include ${CONV_BWDW_KERNEL_HEADER} + ) + target_compile_definitions(dispatcher_conv_bwdw_lib PRIVATE CONV_BWD_WEIGHT_AVAILABLE) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +else() + message(STATUS "No Conv BwdWeight kernel found for ctypes lib - building without kernel") + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endif() + +# ============================================================================= +# GPU Helper Executable +# ============================================================================= + +if(GEMM_KERNEL_HEADERS) + add_executable(gpu_helper gpu_helper.cpp) + + target_include_directories(gpu_helper PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(gpu_helper PRIVATE + hip::device + ) + + target_compile_options(gpu_helper PRIVATE + -include ${GEMM_KERNEL_HEADER} + ) + + set_target_properties(gpu_helper PROPERTIES + CXX_STANDARD 17 + ) +endif() + diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp new file mode 100644 index 0000000000..09e058f80f --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -0,0 +1,175 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Backward Weight Dispatcher ctypes Library + * + * SEPARATE library for backward weight to avoid template conflicts with + * forward/backward_data kernels in the main conv_ctypes_lib. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv_bwdw_lib.so") + * lib.conv_bwdw_init() + * lib.conv_bwdw_run(...) + */ + +#include +#include +#include + +// Minimal includes - matching the C++ example +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/gemm.hpp" // Must be before grouped_convolution for TileGemmTraits +#include "ck_tile/ops/grouped_convolution.hpp" + +// Global state - minimal, no registry needed for direct launch +static bool g_bwdw_initialized = false; + +extern "C" { + +// ============================================================================= +// Initialization (minimal - just sets flag) +// ============================================================================= + +int conv_bwdw_init() +{ + g_bwdw_initialized = true; + return 0; // Return 0 on success (consistent with other init functions) +} + +void conv_bwdw_cleanup() { g_bwdw_initialized = false; } + +// ============================================================================= +// Problem Structure (same as main library) +// ============================================================================= + +struct ConvBwdwProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; +}; + +// ============================================================================= +// Backward Weight Execution +// ============================================================================= + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +static ck_tile::conv::ConvParam build_conv_param(const ConvBwdwProblemC* prob) +{ + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +static float run_bwd_weight_impl(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // Backward weight: A=input, B=grad_output, C=grad_weight + ck_tile::GroupedConvBwdWeightHostArgs args(conv_param, + input_ptr, // in_ptr = input + grad_weight_ptr, // wei_ptr = grad_weight (output) + {}, // ds_ptr + grad_output_ptr, // out_ptr = grad_output + 1 // k_batch + ); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +float conv_bwdw_run(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + // Validate all required pointers before kernel launch + if(!g_bwdw_initialized || !prob) + return -1.0f; + if(!input_ptr || !grad_output_ptr || !grad_weight_ptr) + return -1.0f; // Null data pointer would cause kernel crash + return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream); +#else + return -1.0f; +#endif +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_bwdw_version() { return "1.0.0"; } + +int conv_bwdw_has_kernels() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_count() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_name(int index, char* buffer, int buffer_size) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + if(index != 0 || !buffer || buffer_size <= 0) + return -1; + std::strncpy(buffer, CONV_BWD_WEIGHT_KERNEL_NAME, buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +#else + return -1; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp new file mode 100644 index 0000000000..d3c64621a7 --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -0,0 +1,411 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Supports forward convolution. Backward operations require additional headers. + * + * REQUIRED: Forward kernel header must be force-included via -include flag. + * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv.so") + * lib.conv_dispatcher_init() + * lib.conv_dispatcher_run(...) + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher/conv_utils.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using namespace ck_tile::dispatcher; + +// Global state (using shared_ptr for safe memory management) +static std::shared_ptr g_registry = nullptr; +static std::shared_ptr g_dispatcher = nullptr; +static std::vector g_kernels; + +extern "C" { + +// ============================================================================= +// Initialization +// ============================================================================= + +int conv_dispatcher_init() +{ + if(g_registry) + return 0; // Already initialized + + g_registry = std::make_shared(); + g_dispatcher = std::make_shared(g_registry.get()); + + // Register kernel configurations using simple ConvKernelSet + // (actual kernel launch uses the force-included SelectedConvKernelLauncher) + using namespace ck_tile::dispatcher::conv_decl; + + // Forward kernels (required - must be force-included) + // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb + ConvKernelSet fwd_set; + fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(fwd_set, ConvRegistry::Priority::High); + +#ifdef CONV_BWD_DATA_AVAILABLE + // Backward data kernels + // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 + ConvKernelSet bwd_data_set; + bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); +#endif + + return 0; +} + +int conv_dispatcher_cleanup() +{ + // shared_ptr automatically handles cleanup when reset + g_dispatcher.reset(); + g_registry.reset(); + g_kernels.clear(); + return 0; +} + +// ============================================================================= +// Registry Management +// ============================================================================= + +int conv_dispatcher_get_kernel_count() +{ + if(!g_registry) + return 0; + return static_cast(g_registry->size()); +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(index < 0 || !buffer || buffer_size <= 0) + return -1; + + if(!g_registry) + return -1; + + // Use registry to get kernel names (they are registered with full names) + const auto& kernels = g_registry->all_kernels(); + if(static_cast(index) >= kernels.size()) + return -1; + + const auto* kernel = kernels[index]; + std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ============================================================================= +// Problem Definition +// ============================================================================= + +struct ConvProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; + int direction; // 0=forward, 1=bwd_data, 2=bwd_weight +}; + +// ============================================================================= +// Kernel Selection +// ============================================================================= + +int conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!g_registry || !prob) + return 0; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + return kernel ? 1 : 0; +} + +int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) +{ + if(!g_registry || !prob || !kernel_name || buffer_size <= 0) + return -1; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1; + + std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); + kernel_name[buffer_size - 1] = '\0'; + + return 0; +} + +// ============================================================================= +// Convolution Execution +// ============================================================================= + +// Helper to build ConvParam +static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) +{ + // Determine if this is 2D or 3D convolution + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + // 3D convolution: use all spatial dimensions + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + // 2D convolution: only use H, W dimensions + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +// Forward convolution (required - kernel header must be force-included) +static float run_forward(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + // SelectedConvKernelLauncher is defined in the force-included forward kernel header + return SelectedConvKernelLauncher::launch(args, stream_cfg); +} + +#ifdef CONV_BWD_DATA_AVAILABLE +// Backward data convolution (optional) +// Computes: grad_input = conv_bwd_data(weight, grad_output) +// +// Parameters: +// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) +// weight_ptr: W - frozen weights (const, read-only INPUT) +// grad_input_ptr: dX - gradient for input (writable, OUTPUT) +static float run_bwd_data(const void* grad_output_ptr, + const void* weight_ptr, + void* grad_input_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // CK Tile API uses tensor POSITION names (from forward pass), not data flow: + // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) + // wei_ptr = weight tensor = weight_ptr (W, const) + // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data) + ck_tile::GroupedConvBwdDataHostArgs args( + conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdDataLauncher::launch(args, stream_cfg); +} +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +// Backward weight convolution (optional) +// Parameters: +// input_ptr: original forward input X (const, read-only) +// grad_output_ptr: gradient from next layer dY (const, read-only) +// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) +static float run_bwd_weight(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // GroupedConvBwdWeightHostArgs constructor order: + // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) + // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) + ck_tile::GroupedConvBwdWeightHostArgs args( + conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +/** + * @brief Execute convolution based on direction specified in prob + * + * Parameter mapping varies by direction: + * Forward (direction=0): + * input_ptr = X (input tensor) + * weight_ptr = W (weight tensor) + * output_ptr = Y (output buffer) + * + * Backward Data (direction=1): + * input_ptr = dY (grad_output - gradient from next layer) + * weight_ptr = W (weight tensor, frozen) + * output_ptr = dX (grad_input buffer) + * + * Backward Weight (direction=2): + * input_ptr = X (forward input tensor) + * weight_ptr = dY (grad_output - gradient from next layer) + * output_ptr = dW (grad_weight buffer) + */ +float conv_dispatcher_run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + // Validate all required pointers before kernel launch + if(!g_dispatcher || !prob) + return -1.0f; + if(!input_ptr || !weight_ptr || !output_ptr) + return -1.0f; // Null data pointer would cause kernel crash + + // Build problem for kernel selection + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + // Select kernel + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1.0f; + + // Dispatch based on direction + switch(prob->direction) + { + case 0: // Forward (always available) + return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); + +#ifdef CONV_BWD_DATA_AVAILABLE + case 1: // Backward data + // Convention: caller passes (grad_output, weight, grad_input_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. + // run_bwd_data expects: (grad_output, weight, grad_input) + return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE + case 2: // Backward weight + // Convention: caller passes (input, grad_output, grad_weight_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. + // run_bwd_weight expects: (input, grad_output, grad_weight) + return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + + default: return -1.0f; + } +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_dispatcher_version() { return "1.0.0"; } + +int conv_dispatcher_has_kernels() +{ + return 1; // Forward kernel is required +} + +int conv_dispatcher_has_bwd_data() +{ +#ifdef CONV_BWD_DATA_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_dispatcher_has_bwd_weight() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp new file mode 100644 index 0000000000..85c0c2f2c1 --- /dev/null +++ b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -0,0 +1,401 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GEMM Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Kernel header included via -include at compile time. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_gemm.so") + * lib.dispatcher_init() + * lib.dispatcher_run_gemm(...) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag +// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +// GPU architecture - can be overridden via -DGFX_ARCH="gfx90a" at compile time +#ifndef GFX_ARCH +#define GFX_ARCH "gfx942" +#endif + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup) +static std::shared_ptr g_dispatcher = nullptr; +static bool g_initialized = false; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + return -1; \ + } \ + } + +extern "C" { + +/** + * Initialize dispatcher with a kernel + * Must be called before run_gemm + * + * Returns: 0 on success, -1 on error + */ +int dispatcher_initialize() +{ + if(g_initialized) + { + return 0; // Already initialized + } + + // Create kernel key from the force-included kernel header + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = GFX_ARCH; + + // Register kernel using types from force-included header + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + // Create dispatcher (using shared_ptr for safe memory management) + g_dispatcher = std::make_shared(); + g_initialized = true; + + return 0; +} + +/** + * Get kernel tile configuration + */ +int dispatcher_get_kernel_config(int* tile_m, + int* tile_n, + int* tile_k, + int* warp_tile_m, + int* warp_tile_n, + int* warp_tile_k, + int* warp_m, + int* warp_n, + int* warp_k) +{ + if(!g_initialized) + { + return -1; + } + + auto kernels = Registry::instance().get_all(); + if(kernels.empty()) + { + return -1; + } + + // Get configuration from first kernel + auto& key = kernels[0]->get_key(); + auto& algo = key.algorithm; + + if(tile_m) + *tile_m = algo.tile_shape.m; + if(tile_n) + *tile_n = algo.tile_shape.n; + if(tile_k) + *tile_k = algo.tile_shape.k; + if(warp_tile_m) + *warp_tile_m = algo.warp_tile_shape.m; + if(warp_tile_n) + *warp_tile_n = algo.warp_tile_shape.n; + if(warp_tile_k) + *warp_tile_k = algo.warp_tile_shape.k; + if(warp_m) + *warp_m = algo.wave_shape.m; + if(warp_n) + *warp_n = algo.wave_shape.n; + if(warp_k) + *warp_k = algo.wave_shape.k; + + return 0; +} + +/** + * Get the selected kernel name for a problem + */ +int dispatcher_select_kernel(int64_t M, int64_t N, int64_t K, char* name_buffer, int buffer_size) +{ + if(!g_initialized || !name_buffer || buffer_size <= 0) + { + return -1; + } + + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + + if(!kernel) + { + return -1; + } + + std::string name = kernel->get_name(); + strncpy(name_buffer, name.c_str(), buffer_size - 1); + name_buffer[buffer_size - 1] = '\0'; + + return 0; +} + +/** + * Check if a problem size is supported by available kernels + */ +int dispatcher_is_supported(int64_t M, int64_t N, int64_t K) +{ + if(!g_initialized) + { + return 0; + } + + if(M <= 0 || N <= 0 || K <= 0) + { + return 0; + } + + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + return kernel != nullptr ? 1 : 0; +} + +/** + * Run GEMM on GPU via dispatcher + */ +int dispatcher_run_gemm( + const void* A, const void* B, void* C, int64_t M, int64_t N, int64_t K, float* time_ms) +{ + if(!g_initialized || !A || !B || !C) + { + return -1; + } + + // First check if any kernel supports this problem + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + if(!kernel) + { + if(time_ms) + { + *time_ms = -1.0f; + } + return -2; // No suitable kernel + } + + // Cast to correct types (from force-included header) + const ADataType* A_host = static_cast(A); + const BDataType* B_host = static_cast(B); + CDataType* C_host = static_cast(C); + + // Allocate GPU memory + ADataType* A_dev = nullptr; + BDataType* B_dev = nullptr; + CDataType* C_dev = nullptr; + + auto cleanup_gpu_mem = [&]() { + if(A_dev) + (void)hipFree(A_dev); + if(B_dev) + (void)hipFree(B_dev); + if(C_dev) + (void)hipFree(C_dev); + }; + + if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy input data to GPU + if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Run GEMM via dispatcher + float exec_time; + try + { + exec_time = g_dispatcher->run(A_dev, B_dev, C_dev, problem); + } + catch(const std::exception& e) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy result back to host + if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + if(time_ms) + { + *time_ms = exec_time; + } + + cleanup_gpu_mem(); + return 0; +} + +/** + * Get kernel information + */ +const char* dispatcher_get_kernel_name() { return KERNEL_NAME; } + +/** + * Initialize dispatcher (alias) + */ +int dispatcher_init() { return dispatcher_initialize(); } + +/** + * Get the number of registered kernels + */ +int dispatcher_get_kernel_count() { return static_cast(Registry::instance().size()); } + +/** + * Export registry to JSON string + */ +static std::string g_json_buffer; + +const char* dispatcher_export_registry_json() +{ + auto& registry = Registry::instance(); + + std::ostringstream json; + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << __DATE__ << " " << __TIME__ << "\",\n"; + json << " \"total_kernels\": " << registry.size() << ",\n"; + json << " \"export_version\": \"1.0\",\n"; + json << " \"dispatcher_version\": \"1.0.0\"\n"; + json << " },\n"; + json << " \"statistics\": {\n"; + json << " \"by_datatype\": {},\n"; + json << " \"by_pipeline\": {},\n"; + json << " \"by_scheduler\": {}\n"; + json << " },\n"; + json << " \"kernels\": [\n"; + + auto kernels = registry.get_all(); + for(size_t i = 0; i < kernels.size(); ++i) + { + auto& kernel = kernels[i]; + auto& key = kernel->get_key(); + auto& algo = key.algorithm; + std::string name = kernel->get_name(); + + json << " {\n"; + json << " \"identifier\": \"" << key.encode_identifier() << "\",\n"; + json << " \"name\": \"" << name << "\",\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\"m\": " << algo.tile_shape.m + << ", \"n\": " << algo.tile_shape.n << ", \"k\": " << algo.tile_shape.k << "},\n"; + json << " \"wave_shape\": {\"m\": " << unsigned(algo.wave_shape.m) + << ", \"n\": " << unsigned(algo.wave_shape.n) + << ", \"k\": " << unsigned(algo.wave_shape.k) << "},\n"; + json << " \"warp_tile_shape\": {\"m\": " << unsigned(algo.warp_tile_shape.m) + << ", \"n\": " << unsigned(algo.warp_tile_shape.n) + << ", \"k\": " << unsigned(algo.warp_tile_shape.k) << "},\n"; + json << " \"block_size\": " << algo.block_size << ",\n"; + json << " \"persistent\": " << (algo.persistent ? "true" : "false") << ",\n"; + json << " \"double_buffer\": " << (algo.double_buffer ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (algo.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (algo.transpose_c ? "true" : "false") << "\n"; + json << " }\n"; + json << " }"; + if(i < kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + + json << " ]\n"; + json << "}\n"; + + g_json_buffer = json.str(); + return g_json_buffer.c_str(); +} + +/** + * Cleanup dispatcher resources + */ +void dispatcher_cleanup() +{ + g_dispatcher.reset(); + g_initialized = false; +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gpu_helper.cpp b/dispatcher/bindings/ctypes/gpu_helper.cpp new file mode 100644 index 0000000000..1c72c14e39 --- /dev/null +++ b/dispatcher/bindings/ctypes/gpu_helper.cpp @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GPU Helper - C++ executable for GPU GEMM execution + * + * A CLI tool for Python to execute GPU GEMM with generated kernels. + * Usage: gpu_helper [--validate] + * + * Kernel header included via -include flag at compile time. + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP_ERROR: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM (for validation) +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A: RowMajor, B: ColumnMajor + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main(int argc, char** argv) +{ + // Parse arguments + if(argc < 4) + { + std::cerr << "Usage: " << argv[0] << " [--validate]\n"; + std::cerr << "\nOptions:\n"; + std::cerr << " M, N, K : Problem dimensions\n"; + std::cerr << " --validate : Compare GPU results with CPU reference\n"; + return 1; + } + + int M = std::atoi(argv[1]); + int N = std::atoi(argv[2]); + int K = std::atoi(argv[3]); + bool validate = (argc > 4 && std::string(argv[4]) == "--validate"); + + // Output in JSON-like format for easy Python parsing + std::cout << "{" << std::endl; + std::cout << " \"problem\": {\"M\": " << M << ", \"N\": " << N << ", \"K\": " << K << "}," + << std::endl; + std::cout << " \"kernel\": \"" << KERNEL_NAME << "\"," << std::endl; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cout << " \"error\": \"No kernel selected\"" << std::endl; + std::cout << "}" << std::endl; + return 1; + } + + std::cout << " \"selected_kernel\": \"" << selected->get_name() << "\"," << std::endl; + + // Prepare data: A=1, B=1, so C should be K + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + // GPU execution + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Calculate performance + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + + std::cout << " \"execution\": {" << std::endl; + std::cout << " \"time_ms\": " << gpu_time << "," << std::endl; + std::cout << " \"tflops\": " << tflops << "," << std::endl; + std::cout << " \"flops\": " << (long long)flops << std::endl; + std::cout << " }," << std::endl; + + // Validation + if(validate) + { + std::vector C_cpu(M * N); + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + + int correct = 0; + float max_error = 0.0f; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val) / (std::abs(cpu_val) + 1e-5f); + + max_error = std::max(max_error, error); + + if(error < 0.02f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + std::cout << " \"validation\": {" << std::endl; + std::cout << " \"accuracy\": " << accuracy << "," << std::endl; + std::cout << " \"max_error\": " << max_error << "," << std::endl; + std::cout << " \"correct_elements\": " << correct << "," << std::endl; + std::cout << " \"total_elements\": " << M * N << std::endl; + std::cout << " }," << std::endl; + } + + std::cout << " \"status\": \"success\"" << std::endl; + std::cout << "}" << std::endl; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return 0; +} diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md new file mode 100644 index 0000000000..0bd2966a85 --- /dev/null +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -0,0 +1,197 @@ +# Adding New GPU Architecture Support + +Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatcher. + +> **See also:** [Main Dispatcher README](../README.md) | [Codegen README](README.md) + +## Overview + +The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: + +``` +arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) + → arch_specs_generated.hpp (C++) +``` + +## Quick Start + +```bash +# 1. Edit arch_specs.json +# 2. Run generator +python generate_arch_specs.py +# 3. Rebuild +cd ../build && cmake --build . -j8 +# 4. Test +ctest +``` + +## Step-by-Step Guide + +### Step 1: Edit arch_specs.json + +Add new architecture under `"architectures"`: + +```json +{ + "architectures": { + "gfx1100": { + "family": "rdna3", + "description": "AMD Radeon RX 7000 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[16, 16, 16], [32, 32, 16]] + } + } + } +} +``` + +### Step 2: Configuration Fields + +| Field | Description | Example | +|-------|-------------|---------| +| `family` | GPU family | `"cdna3"`, `"rdna4"` | +| `description` | Human-readable name | `"AMD Instinct MI300"` | +| `warp_size` | Wave/warp size | `64` (CDNA), `32` (RDNA) | +| `lds_capacity_kb` | LDS memory in KB | `64` | +| `warp_configs` | Valid `[warp_m, warp_n, warp_k]` | `[[2,2,1], [4,4,1]]` | +| `warp_tile_combos` | Warp tiles per dtype | See below | + +### Step 3: Warp Tile Combinations + +Map data type combinations to valid warp tile sizes: + +```json +"warp_tile_combos": { + "fp16_fp16_fp16": [[32, 32, 8], [16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[32, 32, 8], [16, 16, 16]], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] +} +``` + +Key format: `{A_dtype}_{B_dtype}_{C_dtype}` + +### Step 4: Run Generator + +```bash +cd dispatcher/codegen +python generate_arch_specs.py +``` + +This generates: +- `arch_specs_generated.py` (Python module) +- `../include/ck_tile/dispatcher/arch_specs_generated.hpp` (C++ header) + +### Step 5: Rebuild and Test + +```bash +cd ../build +cmake --build . -j8 +ctest --output-on-failure +``` + +### Step 6: Verify + +```python +from arch_filter import ArchFilter + +filter = ArchFilter("gfx1100") +is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=128, tile_n=128, tile_k=32, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16 +) +print(f"Valid: {is_valid}") +``` + +## Reference + +### Supported Data Types + +| Key | Description | +|-----|-------------| +| `fp16` | Half precision (16-bit) | +| `bf16` | Brain float 16 | +| `fp32` | Single precision (32-bit) | +| `fp64` | Double precision (64-bit) | +| `fp8` | 8-bit float (E4M3) | +| `bf8` | 8-bit brain float (E5M2) | +| `int8` | 8-bit integer | +| `int4` | 4-bit integer | + +### GPU Families + +| Family | Description | +|--------|-------------| +| `cdna2` | MI200 series (gfx90a) | +| `cdna3` | MI300 series (gfx942) | +| `cdna4` | MI350 series (gfx950) | +| `rdna3` | RX 7000 series (gfx1100) | +| `rdna4` | RX 9000 series (gfx1201) | + +### Pipeline LDS Limits + +| Pipeline | LDS Limit | +|----------|-----------| +| `compv4` | 32 KB | +| `preshufflev2` | 32 KB | +| `default` | 64 KB | + +## Troubleshooting + +### "Unknown GPU architecture" + +1. Check architecture key matches exactly (e.g., `"gfx942"` not `"GFX942"`) +2. Verify you ran `generate_arch_specs.py` +3. Rebuild C++ code + +### Kernels being rejected + +```python +from arch_filter import ArchFilter, KernelConfig + +filter = ArchFilter("gfx942") +result = filter.validate_kernel(config) +print(f"Valid: {result.valid}") +for error in result.errors: + print(f" Error: {error}") +``` + +### Missing warp tile combination + +1. Check `warp_tile_combos` in `arch_specs.json` +2. Ensure `[warp_tile_m, warp_tile_n, warp_tile_k]` is in the list +3. Verify data type key format + +## File Structure + +``` +codegen/ +├── arch_specs.json # Single source of truth (EDIT THIS) +├── generate_arch_specs.py # Generator script +├── arch_specs_generated.py # Generated Python module +└── ADDING_NEW_GPU.md # This file + +include/ck_tile/dispatcher/ +├── arch_specs_generated.hpp # Generated C++ header +└── arch_filter.hpp # C++ filter +``` + +## Best Practices + +1. **Test thoroughly** - Run all tests after adding a new GPU +2. **Start minimal** - Add only validated configurations +3. **Document sources** - Note where warp tile combinations came from +4. **Keep in sync** - If using tile_engine, keep both updated + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/CMakeLists.txt b/dispatcher/codegen/CMakeLists.txt new file mode 100644 index 0000000000..e63dcaab67 --- /dev/null +++ b/dispatcher/codegen/CMakeLists.txt @@ -0,0 +1,125 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Tile GEMM Unified Code Generator + +cmake_minimum_required(VERSION 3.16) + +# Find Python +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# Configuration +set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/unified_gemm_codegen.py") +set(CODEGEN_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json") +set(CODEGEN_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm") + +# Configurable options +set(CK_TILE_GEMM_DATATYPE "fp16" CACHE STRING "GEMM data type (fp16, bf16, fp32, fp8, bf8, int8)") +set(CK_TILE_GEMM_LAYOUT "rcr" CACHE STRING "GEMM layout (rcr, rrr, crr, ccr)") +set(CK_TILE_GEMM_VARIANTS "standard" CACHE STRING "GEMM variants (standard, preshuffle, multi_d)") +set(CK_TILE_GEMM_GPU_TARGET "gfx942" CACHE STRING "Target GPU architecture") +set(CK_TILE_GEMM_PARALLEL ON CACHE BOOL "Enable parallel generation") + +# Custom target to run code generation +add_custom_target(generate_tile_gemm_kernels + COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${CODEGEN_OUTPUT_DIR} + --datatype ${CK_TILE_GEMM_DATATYPE} + --layout ${CK_TILE_GEMM_LAYOUT} + --gpu-target ${CK_TILE_GEMM_GPU_TARGET} + --config ${CODEGEN_CONFIG} + --variants ${CK_TILE_GEMM_VARIANTS} + $<$>:--no-parallel> + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Generating CK Tile GEMM kernels and dispatcher wrappers..." + VERBATIM +) + +# Create output directory +file(MAKE_DIRECTORY ${CODEGEN_OUTPUT_DIR}) + +# Add generated headers to include path +include_directories(${CODEGEN_OUTPUT_DIR}) + +# Installation +install(FILES + ${CODEGEN_SCRIPT} + ${CODEGEN_CONFIG} + README.md + DESTINATION share/ck_tile/codegen +) + +# Helper function for projects to generate kernels +function(ck_tile_generate_gemm_kernels) + set(options PARALLEL) + set(oneValueArgs OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG) + set(multiValueArgs VARIANTS) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + # Set defaults + if(NOT ARG_OUTPUT_DIR) + set(ARG_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm") + endif() + if(NOT ARG_DATATYPE) + set(ARG_DATATYPE "fp16") + endif() + if(NOT ARG_LAYOUT) + set(ARG_LAYOUT "rcr") + endif() + if(NOT ARG_GPU_TARGET) + set(ARG_GPU_TARGET "gfx942") + endif() + if(NOT ARG_CONFIG) + set(ARG_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json") + endif() + if(NOT ARG_VARIANTS) + set(ARG_VARIANTS "standard") + endif() + + # Build command + set(CMD ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${ARG_OUTPUT_DIR} + --datatype ${ARG_DATATYPE} + --layout ${ARG_LAYOUT} + --gpu-target ${ARG_GPU_TARGET} + --config ${ARG_CONFIG} + --variants ${ARG_VARIANTS} + ) + + if(NOT ARG_PARALLEL) + list(APPEND CMD --no-parallel) + endif() + + # Execute + execute_process( + COMMAND ${CMD} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE RESULT + OUTPUT_VARIABLE OUTPUT + ERROR_VARIABLE ERROR + ) + + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR "Failed to generate GEMM kernels:\n${ERROR}") + else() + message(STATUS "Generated GEMM kernels: ${OUTPUT}") + endif() +endfunction() + +# Example usage documentation +message(STATUS "CK Tile GEMM Code Generator configured") +message(STATUS " Script: ${CODEGEN_SCRIPT}") +message(STATUS " Config: ${CODEGEN_CONFIG}") +message(STATUS " Output: ${CODEGEN_OUTPUT_DIR}") +message(STATUS "") +message(STATUS "To generate kernels:") +message(STATUS " cmake --build . --target generate_tile_gemm_kernels") +message(STATUS "") +message(STATUS "Or use CMake function:") +message(STATUS " ck_tile_generate_gemm_kernels(") +message(STATUS " OUTPUT_DIR ./generated") +message(STATUS " DATATYPE fp16") +message(STATUS " LAYOUT rcr") +message(STATUS " VARIANTS standard preshuffle multi_d") +message(STATUS " PARALLEL") +message(STATUS " )") diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md new file mode 100644 index 0000000000..2d753924f5 --- /dev/null +++ b/dispatcher/codegen/README.md @@ -0,0 +1,123 @@ +# CK Tile GEMM Unified Code Generator + +Single source of truth for all GEMM kernel generation. + +> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts. + +## Quick Start + +```bash +cd dispatcher/codegen + +# Generate standard FP16 kernels +python3 unified_gemm_codegen.py \ + --output-dir ../build/generated_kernels \ + --datatype fp16 \ + --layout rcr \ + --variants standard + +# Generate all variants +python3 unified_gemm_codegen.py \ + --output-dir ../build/generated_kernels \ + --variants standard preshuffle multi_d +``` + +## Using from Python + +```python +from ctypes_utils import CodegenRunner, KernelConfig + +# Generate from specific config +config = KernelConfig(tile_m=256, tile_n=256, tile_k=64) +codegen = CodegenRunner() +result = codegen.generate_from_config(config) + +# Generate variant +result = codegen.generate("preshuffle") + +# Generate all +results = codegen.generate_all() +``` + +## Command Line Options + +| Option | Values | Description | +|--------|--------|-------------| +| `--output-dir` | path | Output directory | +| `--datatype` | `fp16`, `bf16`, `fp32`, `int8` | Data type | +| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts | +| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU | +| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants | +| `--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set | + +### Layout Notation + +- `R` = Row-major, `C` = Column-major +- Order: A, B, C (e.g., `rcr` = A row, B col, C row) + +## Variants + +### Standard +Basic GEMM: `C = A × B` + +### PreShuffle +Optimized weight access with LDS pre-shuffling. Best for large matrices. + +### Multi-D +Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` + +Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` + +## Output Structure + +``` +generated_kernels/ +├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp +├── gemm_fp16_rcr_compv4_..._preshuffle.hpp +├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp +└── ... +``` + +## Configuration Files + +### arch_specs.json + +GPU architecture specifications (single source of truth): + +```json +{ + "architectures": { + "gfx942": { + "family": "cdna3", + "warp_size": 64, + "warp_configs": [[2, 2, 1], [4, 4, 1]], + ... + } + } +} +``` + +### preselected_kernels.py + +Curated kernel sets for common use cases. + +## Adding New GPU Support + +See [ADDING_NEW_GPU.md](ADDING_NEW_GPU.md) for complete guide. + +Quick steps: +1. Edit `arch_specs.json` +2. Run `python generate_arch_specs.py` +3. Rebuild + +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| "Arguments not supported" | Check tile config validity | +| Missing element-wise op | Check `elementwise_ops.hpp` | +| Compilation errors | Verify C++17, include paths | + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/arch_filter.py b/dispatcher/codegen/arch_filter.py new file mode 100644 index 0000000000..67f146045b --- /dev/null +++ b/dispatcher/codegen/arch_filter.py @@ -0,0 +1,1012 @@ +#!/usr/bin/env python + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Architecture-Specific Kernel Filtering for CK Tile Dispatcher + +Unified filtering mechanism for validating kernel configurations against +GPU architecture capabilities. Uses arch_specs.json as single source of truth. + +Key Features: +- GPU architecture-specific warp tile and warp configuration validation +- Data type compatibility checking +- Trait combination validation (pipeline, epilogue, scheduler) +- LDS capacity validation +- Single source of truth (arch_specs.json) + +Usage: + from arch_filter import ArchFilter, get_supported_archs + + # Create filter for specific architecture + filter = ArchFilter("gfx942") + + # Validate a kernel configuration + is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=256, tile_n=256, tile_k=64, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", epilogue="cshuffle", scheduler="intrawave" + ) + + # Get detailed validation results + result = filter.validate_kernel_detailed(...) + print(result.valid, result.errors) +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Any +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class OperatorType(Enum): + """Supported operator types for kernel validation""" + + GEMM = "gemm" + GEMM_PRESHUFFLE = "gemm_preshuffle" + GEMM_MULTI_D = "gemm_multi_d" + CONV_FWD = "conv_fwd" + CONV_BWD_DATA = "conv_bwd_data" + CONV_BWD_WEIGHT = "conv_bwd_weight" + CONV3D_FWD = "conv3d_fwd" + CONV3D_BWD_DATA = "conv3d_bwd_data" + CONV3D_BWD_WEIGHT = "conv3d_bwd_weight" + + +# Operator-specific tile constraints +# Different operators may have different minimum tile sizes or alignment requirements +OPERATOR_TILE_CONSTRAINTS = { + OperatorType.GEMM: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.GEMM_PRESHUFFLE: { + "min_tile_m": 64, + "min_tile_n": 64, + "min_tile_k": 32, + "tile_m_alignment": 32, + "tile_n_alignment": 32, + "tile_k_alignment": 16, + }, + OperatorType.GEMM_MULTI_D: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.CONV_FWD: { + "min_tile_m": 1, # N dimension can be 1 + "min_tile_n": 16, # K (output channels) should be reasonable + "min_tile_k": 16, # C (input channels) should be reasonable + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_DATA: { + "min_tile_m": 1, + "min_tile_n": 16, # C (input channels) + "min_tile_k": 16, # K (output channels) + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_WEIGHT: { + "min_tile_m": 16, # K (output channels) + "min_tile_n": 16, # C (input channels) + "min_tile_k": 1, # Spatial reduction dimension + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 1, + }, +} + +# Add 3D convolution constraints (same as 2D for now) +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_FWD] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_FWD +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_DATA] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_DATA +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_WEIGHT] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_WEIGHT +] + +# ============================================================================= +# Import from Generated Module (Single Source of Truth) +# ============================================================================= + +# Try to import from the generated module (created from arch_specs.json) +try: + from arch_specs_generated import ( + ARCH_FAMILY_MAP, + ELEMENT_SIZE_MAP, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_PIPELINES, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + DTYPE_COMBINATIONS, + ) + + _USING_GENERATED = True +except ImportError: + # Fallback to hardcoded values if generated module not available + logger.warning( + "arch_specs_generated.py not found, using fallback values. " + "Run 'python generate_arch_specs.py' to generate." + ) + _USING_GENERATED = False + + # Fallback data (minimal subset for basic operation) + ARCH_FAMILY_MAP = { + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1201": "rdna4", + } + + ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "int32": 4, + } + + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + } + + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + # Key format: A_B_Acc (e.g., fp16_fp16_fp32 = A/B are fp16, accumulator is fp32) + # These match tile_engine's GEMM_WARP_TILE_SUPPORTED_COMBINATIONS + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + } + + # Preshuffle-specific warp tile combinations (no [4, 64, 16]) + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + }, + } + + PRESHUFFLE_PIPELINES = ["preshufflev2"] + + LDS_CAPACITY_LIMITS = {"compv4": 32768, "preshufflev2": 32768, "default": 65536} + + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + DTYPE_COMBINATIONS = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, + } + + +# ============================================================================= +# GPU Family Enum (for backwards compatibility) +# ============================================================================= + + +class GpuFamily(Enum): + """GPU architecture families""" + + CDNA2 = "cdna2" + CDNA3 = "cdna3" + CDNA4 = "cdna4" + RDNA4 = "rdna4" + + +# ============================================================================= +# Dtype Validation Helpers +# ============================================================================= + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid for GEMM.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_dtype_acc(dtype_a: str, dtype_b: str) -> str: + """Get the accumulator type for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + info = DTYPE_COMBINATIONS.get(key, {"acc": "fp32"}) + return info["acc"] + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) + + +# ============================================================================= +# Validation Result Types +# ============================================================================= + + +@dataclass +class ValidationResult: + """Result of kernel configuration validation""" + + valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + def __bool__(self) -> bool: + return self.valid + + def add_error(self, msg: str): + self.errors.append(msg) + self.valid = False + + def add_warning(self, msg: str): + self.warnings.append(msg) + + +@dataclass +class KernelConfig: + """Kernel configuration for validation""" + + # Data types + datatype_a: str + datatype_b: str + datatype_c: str + + # Tile dimensions + tile_m: int + tile_n: int + tile_k: int + + # Warp configuration + warp_m: int + warp_n: int + warp_k: int + + # Warp tile dimensions + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + # Traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # Layout (for whole-workgroup cover validation) + layout: str = "rcr" + + # Operator type (affects validation rules) + operator: OperatorType = OperatorType.GEMM + + @property + def dtype_key(self) -> str: + """Generate data type combination key for warp tile lookup. + + Uses accumulator dtype (not output C type) to match the format + used in WARP_TILE_SUPPORTED_COMBINATIONS dictionaries which are + keyed as {datatype_a}_{datatype_b}_{accumulator_dtype}. + """ + acc_dtype = get_dtype_acc(self.datatype_a, self.datatype_b) + return f"{self.datatype_a}_{self.datatype_b}_{acc_dtype}" + + +# ============================================================================= +# Architecture Filter Class +# ============================================================================= + + +class ArchFilter: + """ + Architecture-specific kernel configuration filter. + + Validates kernel configurations against GPU architecture capabilities + to ensure only compatible kernels are registered. + + Example: + filter = ArchFilter("gfx942") + + # Quick validation + if filter.is_kernel_valid(config): + registry.register_kernel(kernel) + + # Detailed validation with error messages + result = filter.validate_kernel(config) + if not result.valid: + for error in result.errors: + print(f"Validation failed: {error}") + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = True): + """ + Initialize architecture filter. + + Args: + gpu_arch: GPU architecture string (e.g., "gfx942", "gfx90a") + strict_mode: If True, unknown configurations are rejected. + If False, unknown configurations pass with warnings. + """ + self.gpu_arch = gpu_arch.lower() + self.strict_mode = strict_mode + self.family = ARCH_FAMILY_MAP.get(self.gpu_arch) + + if self.family is None and strict_mode: + raise ValueError( + f"Unknown GPU architecture: {gpu_arch}. " + f"Supported: {list(ARCH_FAMILY_MAP.keys())}" + ) + + def validate_kernel(self, config: KernelConfig) -> ValidationResult: + """ + Validate a kernel configuration against architecture constraints. + + Validation is performed based on the operator type, as different + operators (GEMM, Conv FWD, Conv BWD) have different constraints. + + Args: + config: Kernel configuration to validate + + Returns: + ValidationResult with valid flag and error/warning messages + """ + result = ValidationResult(valid=True) + + # Operator-specific tile constraint validation + self._validate_operator_constraints(config, result) + if not result.valid and self.strict_mode: + return result + + # Basic sanity checks + self._validate_dimensions(config, result) + if not result.valid and self.strict_mode: + return result + + # Warp configuration validation + self._validate_warp_config(config, result) + + # Warp tile combination validation + self._validate_warp_tile_combo(config, result) + + # Trait combination validation + self._validate_trait_combo(config, result) + + # LDS capacity validation + self._validate_lds_capacity(config, result) + + # Dimension alignment validation + self._validate_dimension_alignment(config, result) + + return result + + def _validate_operator_constraints( + self, config: KernelConfig, result: ValidationResult + ): + """Validate operator-specific tile constraints""" + constraints = OPERATOR_TILE_CONSTRAINTS.get(config.operator) + + if constraints is None: + # Unknown operator - add warning but don't fail + result.add_warning( + f"Unknown operator type: {config.operator}. " + f"Skipping operator-specific validation." + ) + return + + # Validate minimum tile sizes + min_tile_m = constraints.get("min_tile_m", 1) + min_tile_n = constraints.get("min_tile_n", 1) + min_tile_k = constraints.get("min_tile_k", 1) + + if config.tile_m < min_tile_m: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"< minimum ({min_tile_m})" + ) + if config.tile_n < min_tile_n: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"< minimum ({min_tile_n})" + ) + if config.tile_k < min_tile_k: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"< minimum ({min_tile_k})" + ) + + # Validate tile alignment + tile_m_align = constraints.get("tile_m_alignment", 1) + tile_n_align = constraints.get("tile_n_alignment", 1) + tile_k_align = constraints.get("tile_k_alignment", 1) + + if tile_m_align > 1 and config.tile_m % tile_m_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"must be aligned to {tile_m_align}" + ) + if tile_n_align > 1 and config.tile_n % tile_n_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"must be aligned to {tile_n_align}" + ) + if tile_k_align > 1 and config.tile_k % tile_k_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"must be aligned to {tile_k_align}" + ) + + def is_kernel_valid( + self, + datatype_a: str = "fp16", + datatype_b: str = "fp16", + datatype_c: str = "fp16", + tile_m: int = 256, + tile_n: int = 256, + tile_k: int = 64, + warp_m: int = 2, + warp_n: int = 2, + warp_k: int = 1, + warp_tile_m: int = 32, + warp_tile_n: int = 32, + warp_tile_k: int = 16, + pipeline: str = "compv4", + epilogue: str = "cshuffle", + scheduler: str = "intrawave", + layout: str = "rcr", + operator: Optional[OperatorType] = None, + ) -> bool: + """ + Quick validation check for a kernel configuration. + + Args: + datatype_a, datatype_b, datatype_c: Data types for A, B, C matrices + tile_m, tile_n, tile_k: Block tile dimensions + warp_m, warp_n, warp_k: Warp/wave configuration + warp_tile_m, warp_tile_n, warp_tile_k: Warp tile dimensions + pipeline, epilogue, scheduler: Kernel traits + layout: Matrix layout (e.g., "rcr") + operator: Operator type (GEMM, CONV_FWD, CONV_BWD_DATA, etc.) + Affects validation rules for tile constraints. + Defaults to GEMM if not specified. + + Returns: + True if configuration is valid for this architecture + """ + config = KernelConfig( + datatype_a=datatype_a.lower(), + datatype_b=datatype_b.lower(), + datatype_c=datatype_c.lower(), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + pipeline=pipeline.lower(), + epilogue=epilogue.lower(), + scheduler=scheduler.lower(), + layout=layout.lower(), + operator=operator if operator is not None else OperatorType.GEMM, + ) + return self.validate_kernel(config).valid + + def _validate_dimensions(self, config: KernelConfig, result: ValidationResult): + """Validate basic dimension constraints""" + if config.tile_m <= 0 or config.tile_n <= 0 or config.tile_k <= 0: + result.add_error( + f"Tile dimensions must be positive: " + f"{config.tile_m}x{config.tile_n}x{config.tile_k}" + ) + + if config.warp_m <= 0 or config.warp_n <= 0 or config.warp_k <= 0: + result.add_error( + f"Warp dimensions must be positive: " + f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + ) + + if ( + config.warp_tile_m <= 0 + or config.warp_tile_n <= 0 + or config.warp_tile_k <= 0 + ): + result.add_error( + f"Warp tile dimensions must be positive: " + f"{config.warp_tile_m}x{config.warp_tile_n}x{config.warp_tile_k}" + ) + + # Check warp tiles fit within block tiles + if config.warp_m * config.warp_tile_m > config.tile_m: + result.add_error( + f"warp_m * warp_tile_m ({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m}) > tile_m ({config.tile_m})" + ) + if config.warp_n * config.warp_tile_n > config.tile_n: + result.add_error( + f"warp_n * warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n}) > tile_n ({config.tile_n})" + ) + if config.warp_k * config.warp_tile_k > config.tile_k: + result.add_error( + f"warp_k * warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + f"{config.warp_k * config.warp_tile_k}) > tile_k ({config.tile_k})" + ) + + def _validate_warp_config(self, config: KernelConfig, result: ValidationResult): + """Validate warp configuration against architecture""" + allowed = WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + current = [config.warp_m, config.warp_n, config.warp_k] + + if not allowed: + msg = f"No warp configurations defined for {self.gpu_arch}" + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + if current not in allowed: + result.add_error( + f"Invalid warp configuration {current} for {self.gpu_arch}. " + f"Allowed: {allowed}" + ) + + def _validate_warp_tile_combo(self, config: KernelConfig, result: ValidationResult): + """Validate warp tile combination against architecture and data types""" + # Use preshuffle-specific warp tiles for preshuffle operator + if config.operator == OperatorType.GEMM_PRESHUFFLE: + gpu_combos = PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + self.gpu_arch, {} + ) + combo_source = "preshuffle" + else: + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + combo_source = "standard" + + if not gpu_combos: + msg = ( + f"No {combo_source} warp tile combinations defined for {self.gpu_arch}" + ) + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + dtype_combos = gpu_combos.get(config.dtype_key, []) + if not dtype_combos: + # Data type combo not explicitly listed - may still be valid + result.add_warning( + f"No {combo_source} warp tile combinations defined for {config.dtype_key} on {self.gpu_arch}" + ) + return + + current = [config.warp_tile_m, config.warp_tile_n, config.warp_tile_k] + if current not in dtype_combos: + result.add_error( + f"Invalid warp tile {current} for {config.dtype_key} on {self.gpu_arch} ({combo_source}). " + f"Allowed: {dtype_combos}" + ) + + def _validate_trait_combo(self, config: KernelConfig, result: ValidationResult): + """Validate trait (pipeline, epilogue, scheduler) combination""" + # Preshuffle requires specific pipelines + if config.operator == OperatorType.GEMM_PRESHUFFLE: + if config.pipeline not in PRESHUFFLE_PIPELINES: + result.add_error( + f"Preshuffle GEMM requires pipeline in {PRESHUFFLE_PIPELINES}, " + f"got {config.pipeline}" + ) + + # Conv backward operations only support compv3/mem pipelines + # (compv4/compv5 have template issues: transpose_tile2d for bwd_weight, + # get_length for bwd_data in ck_tile kernels) + conv_bwd_operators = { + OperatorType.CONV_BWD_DATA, + OperatorType.CONV_BWD_WEIGHT, + OperatorType.CONV3D_BWD_DATA, + OperatorType.CONV3D_BWD_WEIGHT, + } + conv_bwd_supported_pipelines = {"compv3", "mem"} + if config.operator in conv_bwd_operators: + if config.pipeline not in conv_bwd_supported_pipelines: + result.add_error( + f"Conv backward operations require pipeline in " + f"{conv_bwd_supported_pipelines}, got {config.pipeline}. " + f"(compv4/compv5 have ck_tile template compatibility issues)" + ) + + combo = (config.pipeline, config.epilogue, config.scheduler) + if combo in TRAIT_UNSUPPORTED_COMBINATIONS: + result.add_error( + f"Unsupported trait combination: pipeline={config.pipeline}, " + f"epilogue={config.epilogue}, scheduler={config.scheduler}" + ) + + def _validate_lds_capacity(self, config: KernelConfig, result: ValidationResult): + """Validate LDS (Local Data Share) memory capacity""" + elem_size_a = ELEMENT_SIZE_MAP.get(config.datatype_a, 2) + elem_size_b = ELEMENT_SIZE_MAP.get(config.datatype_b, 2) + + matrix_a_size = config.tile_m * config.tile_k * elem_size_a + matrix_b_size = config.tile_n * config.tile_k * elem_size_b + total_lds = matrix_a_size + matrix_b_size + + max_lds = LDS_CAPACITY_LIMITS.get( + config.pipeline, LDS_CAPACITY_LIMITS["default"] + ) + + if total_lds > max_lds: + result.add_error( + f"LDS capacity exceeded: {total_lds} bytes > {max_lds} bytes limit. " + f"Matrix A: {config.tile_m}x{config.tile_k}x{elem_size_a}={matrix_a_size}B, " + f"Matrix B: {config.tile_n}x{config.tile_k}x{elem_size_b}={matrix_b_size}B" + ) + + def _validate_dimension_alignment( + self, config: KernelConfig, result: ValidationResult + ): + """Validate tile dimensions are aligned with warp dimensions""" + if config.tile_m % (config.warp_m * config.warp_tile_m) != 0: + result.add_error( + f"tile_m ({config.tile_m}) must be divisible by " + f"warp_m*warp_tile_m ({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m})" + ) + + if config.tile_n % (config.warp_n * config.warp_tile_n) != 0: + result.add_error( + f"tile_n ({config.tile_n}) must be divisible by " + f"warp_n*warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n})" + ) + + if config.tile_k % (config.warp_k * config.warp_tile_k) != 0: + result.add_error( + f"tile_k ({config.tile_k}) must be divisible by " + f"warp_k*warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + f"{config.warp_k * config.warp_tile_k})" + ) + + def get_supported_warp_configs(self) -> List[List[int]]: + """Get list of supported warp configurations for this architecture""" + return WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + + def get_supported_warp_tiles(self, dtype_key: str) -> List[List[int]]: + """Get list of supported warp tile configurations for given data types""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return gpu_combos.get(dtype_key, []) + + def get_supported_datatypes(self) -> List[str]: + """Get list of data type combinations supported on this architecture""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return list(gpu_combos.keys()) + + +# ============================================================================= +# Registry Filter Integration +# ============================================================================= + + +class RegistryFilter: + """ + Filter wrapper for integrating with dispatcher Registry. + + Provides a callable interface that can be used with Registry.filter() + or during kernel registration. + + Example: + # Create filter for gfx942 + filter = RegistryFilter("gfx942") + + # Use with registry + registry = Registry() + registry.set_kernel_filter(filter) # Auto-filter on registration + + # Or filter existing kernels + valid_kernels = registry.filter(filter.accepts_kernel) + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = False): + """ + Initialize registry filter. + + Args: + gpu_arch: Target GPU architecture + strict_mode: If True, reject unknown configurations + """ + self.arch_filter = ArchFilter(gpu_arch, strict_mode=strict_mode) + self.gpu_arch = gpu_arch + self._rejected_count = 0 + self._accepted_count = 0 + + def accepts_kernel(self, kernel_config: Dict[str, Any]) -> bool: + """ + Check if a kernel configuration should be accepted into the registry. + + Args: + kernel_config: Dictionary with kernel configuration values + + Returns: + True if kernel is valid for target architecture + """ + try: + is_valid = self.arch_filter.is_kernel_valid( + datatype_a=kernel_config.get("dtype_a", "fp16"), + datatype_b=kernel_config.get("dtype_b", "fp16"), + datatype_c=kernel_config.get("dtype_c", "fp16"), + tile_m=kernel_config.get("tile_m", 256), + tile_n=kernel_config.get("tile_n", 256), + tile_k=kernel_config.get("tile_k", 64), + warp_m=kernel_config.get("warp_m", 2), + warp_n=kernel_config.get("warp_n", 2), + warp_k=kernel_config.get("warp_k", 1), + warp_tile_m=kernel_config.get("warp_tile_m", 32), + warp_tile_n=kernel_config.get("warp_tile_n", 32), + warp_tile_k=kernel_config.get("warp_tile_k", 16), + pipeline=kernel_config.get("pipeline", "compv4"), + epilogue=kernel_config.get("epilogue", "cshuffle"), + scheduler=kernel_config.get("scheduler", "intrawave"), + layout=kernel_config.get("layout", "rcr"), + ) + + if is_valid: + self._accepted_count += 1 + else: + self._rejected_count += 1 + + return is_valid + + except Exception as e: + logger.warning(f"Error validating kernel config: {e}") + self._rejected_count += 1 + return False + + def get_stats(self) -> Dict[str, int]: + """Get filtering statistics""" + return { + "accepted": self._accepted_count, + "rejected": self._rejected_count, + "total": self._accepted_count + self._rejected_count, + } + + def reset_stats(self): + """Reset filtering statistics""" + self._accepted_count = 0 + self._rejected_count = 0 + + def __call__(self, kernel_config: Dict[str, Any]) -> bool: + """Callable interface for use with filter functions""" + return self.accepts_kernel(kernel_config) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> Optional[str]: + """Get the GPU family for an architecture""" + family = ARCH_FAMILY_MAP.get(gpu_arch.lower()) + return family if family else None # ARCH_FAMILY_MAP contains strings, not Enums + + +def create_filter_for_current_gpu() -> Optional[ArchFilter]: + """ + Create a filter for the current GPU (auto-detect). + + Returns: + ArchFilter for detected GPU, or None if detection fails + """ + try: + import subprocess + + result = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=5) + + for line in result.stdout.split("\n"): + if "gfx" in line.lower(): + for arch in ARCH_FAMILY_MAP.keys(): + if arch in line.lower(): + return ArchFilter(arch) + + return None + except Exception: + return None + + +def filter_kernel_list( + kernels: List[Dict[str, Any]], gpu_arch: str +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Filter a list of kernel configurations for a specific architecture. + + Args: + kernels: List of kernel configuration dictionaries + gpu_arch: Target GPU architecture + + Returns: + Tuple of (valid_kernels, rejected_kernels) + """ + reg_filter = RegistryFilter(gpu_arch) + valid = [] + rejected = [] + + for kernel in kernels: + if reg_filter.accepts_kernel(kernel): + valid.append(kernel) + else: + rejected.append(kernel) + + return valid, rejected + + +# ============================================================================= +# Main (for testing) +# ============================================================================= + +if __name__ == "__main__": + # Test the filter + print("Testing ArchFilter for gfx942...\n") + + filter_942 = ArchFilter("gfx942") + + # Test valid configuration + print("Test 1: Valid FP16 GEMM kernel") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test invalid warp configuration + print("Test 2: Invalid warp configuration") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=3, + warp_n=3, + warp_k=1, # Invalid! + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test LDS overflow + print("Test 3: LDS capacity overflow") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=512, + tile_n=512, + tile_k=256, # Too large! + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test quick validation + print("Test 4: Quick validation (is_kernel_valid)") + is_valid = filter_942.is_kernel_valid( + tile_m=128, + tile_n=128, + tile_k=32, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=16, + ) + print(f" Valid: {is_valid}") + print() + + # Show supported configurations + print("Supported warp configurations for gfx942:") + for cfg in filter_942.get_supported_warp_configs(): + print(f" {cfg}") + print() + + print("Supported data types for gfx942:") + for dtype in filter_942.get_supported_datatypes(): + print(f" {dtype}") diff --git a/dispatcher/codegen/arch_specs.json b/dispatcher/codegen/arch_specs.json new file mode 100644 index 0000000000..7d8c83fbf7 --- /dev/null +++ b/dispatcher/codegen/arch_specs.json @@ -0,0 +1,270 @@ +{ + "_comment": "Single source of truth for GPU architecture specifications. Edit this file to add new GPU support.", + "_version": "1.2.0", + "_instructions": "See ADDING_NEW_GPU.md for instructions on adding new GPU support.", + "_supported_arch_note": "CK Tile supports: GFX9 (gfx908, gfx90a, gfx942, gfx950), GFX10.3 (gfx103x), GFX11 (gfx110x, gfx115x), GFX12 (gfx120x)", + + "architectures": { + "gfx908": { + "family": "cdna1", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI100", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx90a": { + "family": "cdna2", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI200 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx942": { + "family": "cdna3", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI300 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx950": { + "family": "cdna4", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI350 series", + "warp_size": 64, + "lds_capacity_kb": 160, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 128], [32, 32, 64]], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]] + } + }, + + "gfx1100": { + "family": "rdna3", + "target_family": "gfx11", + "architecture": "rdna", + "description": "AMD Radeon RX 7900 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1200": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1201": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + } + }, + + "element_sizes": { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4 + }, + + "datatype_cpp_map": { + "_comment": "Maps dtype string to CK Tile C++ type for code generation", + "fp16": "ck_tile::half_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "ck_tile::int8_t", + "int4": "ck_tile::pk_int4_t", + "pk_fp4": "ck_tile::pk_fp4_t", + "int32": "ck_tile::int32_t" + }, + + "dtype_combinations": { + "_comment": "All valid (A, B) -> Acc combinations for GEMM from warp_gemm_dispatcher.hpp", + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"} + }, + + "layout_cpp_map": { + "_comment": "Maps layout character to CK Tile C++ type", + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor" + }, + + "pipeline_lds_limits": { + "_comment": "LDS capacity limits in bytes for different pipeline types", + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536 + }, + + "unsupported_trait_combos": { + "_comment": "Only 'mem' pipeline supports interwave scheduler. All compute pipelines only support intrawave.", + "combinations": [ + ["compv3", "cshuffle", "interwave"], + ["compv3", "default", "interwave"], + ["compv4", "cshuffle", "interwave"], + ["compv4", "default", "interwave"], + ["compv5", "cshuffle", "interwave"], + ["compv5", "default", "interwave"], + ["compv6", "cshuffle", "interwave"], + ["compv6", "default", "interwave"], + ["comp_async", "cshuffle", "interwave"], + ["comp_async", "default", "interwave"] + ] + }, + + "preshuffle_warp_tile_combos": { + "_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])", + "gfx90a": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] + }, + "gfx950": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + }, + + "preshuffle_pipelines": { + "_comment": "Pipelines supported for preshuffle GEMM variant", + "supported": ["preshufflev2"] + } +} diff --git a/dispatcher/codegen/arch_specs_generated.py b/dispatcher/codegen/arch_specs_generated.py new file mode 100644 index 0000000000..97f17e9724 --- /dev/null +++ b/dispatcher/codegen/arch_specs_generated.py @@ -0,0 +1,358 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: 2026-01-05T19:34:01.224422 + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. +""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = { + "gfx908": "cdna1", + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1100": "rdna3", + "gfx1200": "rdna4", + "gfx1201": "rdna4", +} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4, +} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = { + "gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], +} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] +WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx908": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx90a": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx942": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx950": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "fp8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]], + }, + "gfx1100": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1200": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1201": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, +} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx90a": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"] + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = { + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536, +} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), +} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, +} + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return ( + pipeline.lower(), + epilogue.lower(), + scheduler.lower(), + ) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return DTYPE_COMBINATIONS.get(key, {"acc": "fp32", "notes": "unknown"}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) diff --git a/dispatcher/codegen/default_config.json b/dispatcher/codegen/default_config.json new file mode 100644 index 0000000000..3ef823fcc2 --- /dev/null +++ b/dispatcher/codegen/default_config.json @@ -0,0 +1,27 @@ +{ + "tile_config": { + "tile_m": [128, 256], + "tile_n": [128, 256], + "tile_k": [32, 64], + "warp_m": [2, 4], + "warp_n": [2, 4], + "warp_k": [1], + "warp_tile_m": [16, 32], + "warp_tile_n": [16, 32], + "warp_tile_k": [16] + }, + "trait_config": { + "pipeline": ["compv4"], + "epilogue": ["cshuffle"], + "scheduler": ["intrawave"], + "pad_m": [false], + "pad_n": [false], + "pad_k": [false], + "persistent": [false, true] + }, + "multi_d_config": { + "elementwise_ops": ["MultiDAdd", "Relu", "Gelu"], + "num_d_tensors": [1, 2] + } +} + diff --git a/dispatcher/codegen/generate_arch_specs.py b/dispatcher/codegen/generate_arch_specs.py new file mode 100644 index 0000000000..5b6fc2971b --- /dev/null +++ b/dispatcher/codegen/generate_arch_specs.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Architecture Specs Generator + +Generates both Python and C++ code from a single JSON source of truth. +This ensures consistency between Python codegen and C++ runtime filtering. + +Usage: + python generate_arch_specs.py [--json arch_specs.json] [--output-dir .] + + # Regenerate after editing arch_specs.json: + python generate_arch_specs.py + +Output: + - arch_specs_generated.py (Python module with arch data) + - arch_specs_generated.hpp (C++ header with arch data) +""" + +import json +import argparse +from pathlib import Path +from datetime import datetime +from typing import Dict, Any + +SCRIPT_DIR = Path(__file__).parent + + +def load_arch_specs(json_path: Path) -> Dict[str, Any]: + """Load architecture specifications from JSON file.""" + with open(json_path) as f: + return json.load(f) + + +def generate_python_module(specs: Dict[str, Any], output_path: Path): + """Generate Python module from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + unsupported = specs["unsupported_trait_combos"]["combinations"] + + # Build warp configs dict + warp_configs_str = "{\n" + for arch, data in archs.items(): + warp_configs_str += f' "{arch}": {data["warp_configs"]},\n' + warp_configs_str += "}" + + # Build warp tile combos dict + warp_tile_str = "{\n" + for arch, data in archs.items(): + warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in data["warp_tile_combos"].items(): + warp_tile_str += f' "{dtype}": {combos},\n' + warp_tile_str += " },\n" + warp_tile_str += "}" + + # Build arch family map + arch_family_str = "{\n" + for arch, data in archs.items(): + arch_family_str += f' "{arch}": "{data["family"]}",\n' + arch_family_str += "}" + + # Build unsupported combos set + unsupported_str = "{\n" + for combo in unsupported: + unsupported_str += f' ("{combo[0]}", "{combo[1]}", "{combo[2]}"),\n' + unsupported_str += "}" + + # Pipeline LDS limits + pipeline_limits_clean = { + k: v for k, v in pipeline_limits.items() if not k.startswith("_") + } + + # Build dtype combinations dict + dtype_combos = specs.get("dtype_combinations", {}) + dtype_combos_str = "{\n" + for key, info in dtype_combos.items(): + if not key.startswith("_"): + dtype_combos_str += f' "{key}": {{"acc": "{info["acc"]}", "notes": "{info["notes"]}"}},\n' + dtype_combos_str += "}" + + # Build preshuffle warp tile combos dict (operator-specific) + preshuffle_combos = specs.get("preshuffle_warp_tile_combos", {}) + preshuffle_warp_tile_str = "{\n" + for arch, dtype_combos_dict in preshuffle_combos.items(): + if not arch.startswith("_"): + preshuffle_warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in dtype_combos_dict.items(): + preshuffle_warp_tile_str += f' "{dtype}": {combos},\n' + preshuffle_warp_tile_str += " },\n" + preshuffle_warp_tile_str += "}" + + # Build preshuffle pipelines list + preshuffle_pipelines = specs.get("preshuffle_pipelines", {}).get( + "supported", ["preshufflev2"] + ) + preshuffle_pipelines_str = str(preshuffle_pipelines) + + content = f'''# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: {timestamp} + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. +""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = {arch_family_str} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = {element_sizes} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {warp_configs_str} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] +WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {warp_tile_str} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {preshuffle_warp_tile_str} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = {preshuffle_pipelines_str} + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = {pipeline_limits_clean} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {unsupported_str} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {dtype_combos_str} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {{}}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return (pipeline.lower(), epilogue.lower(), scheduler.lower()) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return DTYPE_COMBINATIONS.get(key, {{"acc": "fp32", "notes": "unknown"}}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) +''' + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def generate_cpp_header(specs: Dict[str, Any], output_path: Path): + """Generate C++ header from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + specs["unsupported_trait_combos"]["combinations"] + + # Build arch enum and string functions + arch_enums = [] + arch_to_string_cases = [] + string_to_arch_cases = [] + + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + arch_enums.append(f" {enum_name}, // {data['description']}") + arch_to_string_cases.append( + f' case GpuArch::{enum_name}: return "{arch}";' + ) + string_to_arch_cases.append( + f' if (arch_str == "{arch}") return GpuArch::{enum_name};' + ) + + # Build warp configs switch + warp_config_cases = [] + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + configs = ", ".join( + [f"{{{c[0]}, {c[1]}, {c[2]}}}" for c in data["warp_configs"]] + ) + warp_config_cases.append( + f" case GpuArch::{enum_name}: return {{{configs}}};" + ) + + # Build element size switch + # Include all data types defined in kernel_key.hpp DataType enum + elem_size_cases = [] + dtype_enum_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "fp8": "FP8", + "bf8": "BF8", + "int8": "INT8", + "int4": "INT4", + "int32": "INT32", + } + for dtype, size in element_sizes.items(): + if dtype in dtype_enum_map: + elem_size_cases.append( + f" case DataType::{dtype_enum_map[dtype]}: return {float(size)}f;" + ) + + # Build LDS limits + lds_limit_cases = [] + pipeline_enum_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + default_lds = pipeline_limits.get("default", 65536) + for pipeline, limit in pipeline_limits.items(): + if pipeline in pipeline_enum_map: + lds_limit_cases.append( + f" if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};" + ) + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: {timestamp} + * + * To update this file: + * 1. Edit arch_specs.json + * 2. Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile {{ +namespace dispatcher {{ +namespace arch_specs {{ + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t {{ +{chr(10).join(arch_enums)} + UNKNOWN +}}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(arch_to_string_cases)} + default: return "unknown"; + }} +}} + +inline GpuArch string_to_arch(const std::string& arch_str) {{ +{chr(10).join(string_to_arch_cases)} + return GpuArch::UNKNOWN; +}} + +// ============================================================================= +// Element Size (Generated) +// ============================================================================= + +inline float element_size(DataType dtype) {{ + switch (dtype) {{ +{chr(10).join(elem_size_cases)} + default: return 2.0f; + }} +}} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(warp_config_cases)} + default: return {{}}; + }} +}} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) {{ +{chr(10).join(lds_limit_cases)} + return {default_lds}; // Default +}} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) {{ + // Generated from unsupported_trait_combos in arch_specs.json + if (scheduler == Scheduler::Interwave) {{ + if (pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) {{ + return true; + }} + }} + return false; +}} + +}} // namespace arch_specs +}} // namespace dispatcher +}} // namespace ck_tile +""" + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate Python and C++ code from arch_specs.json" + ) + parser.add_argument( + "--json", + type=Path, + default=SCRIPT_DIR / "arch_specs.json", + help="Path to arch_specs.json", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=SCRIPT_DIR, + help="Output directory for generated files", + ) + parser.add_argument( + "--cpp-output-dir", + type=Path, + default=None, + help="Output directory for C++ header (defaults to dispatcher/include/...)", + ) + + args = parser.parse_args() + + # Load specs + print(f"Loading: {args.json}") + specs = load_arch_specs(args.json) + + # Generate Python module + py_output = args.output_dir / "arch_specs_generated.py" + generate_python_module(specs, py_output) + + # Generate C++ header + if args.cpp_output_dir: + cpp_output = args.cpp_output_dir / "arch_specs_generated.hpp" + else: + cpp_output = ( + SCRIPT_DIR.parent + / "include" + / "ck_tile" + / "dispatcher" + / "arch_specs_generated.hpp" + ) + + cpp_output.parent.mkdir(parents=True, exist_ok=True) + generate_cpp_header(specs, cpp_output) + + print("\nDone! To apply changes:") + print(" 1. Python code will automatically use arch_specs_generated.py") + print(" 2. C++ code includes arch_specs_generated.hpp") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py new file mode 100644 index 0000000000..024ec4a7c8 --- /dev/null +++ b/dispatcher/codegen/generate_dispatcher_registration.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate dispatcher registration code for CK Tile kernels + +This script generates C++ registration code that instantiates TileKernelInstance +templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem. +""" + +import json +import argparse +from pathlib import Path +from typing import List +from dataclasses import dataclass + + +@dataclass +class KernelConfig: + """Kernel configuration for registration""" + + name: str + header_file: str + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + block_size: int + pipeline: str + epilogue: str + scheduler: str + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + double_buffer: bool + transpose_c: bool + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + +def generate_registration_header(kernels: List[KernelConfig], output_file: Path): + """Generate registration header file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/kernel_registration.hpp" + +// Include all generated kernel headers +""" + + # Add includes for all kernel headers + for kernel in kernels: + content += f'#include "{kernel.header_file}"\n' + + content += """ + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +/// Register all generated kernels with the dispatcher +inline void register_all_kernels(Registry& registry) +{ +""" + + # Add registration calls for each kernel + for kernel in kernels: + # Extract the SelectedKernel type name from the header file + # Assuming the header defines a type like: using SelectedKernel = ... + kernel_type = f"SelectedKernel_{kernel.name}" + + content += f""" // Register {kernel.name} + register_tile_kernel<{kernel_type}>(registry, "{kernel.name}"); +""" + + content += """} + +/// Register all generated kernels with the global registry +inline void register_all_kernels() +{ + auto& registry = Registry::instance(); + register_all_kernels(registry); +} + +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration header: {output_file}") + + +def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path): + """Generate registration implementation file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#include "dispatcher_registration.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +// Explicit instantiations to reduce compile time +// These ensure the templates are instantiated once + +""" + + for kernel in kernels: + kernel_type = f"SelectedKernel_{kernel.name}" + content += f"template class backends::TileKernelInstance<{kernel_type}>;\n" + + content += """ +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration implementation: {output_file}") + + +def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path): + """Generate a wrapper header that defines SelectedKernel type""" + + wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "{kernel.header_file}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Type alias for dispatcher registration +// This allows the registration code to reference the kernel type +using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + wrapper_file.write_text(content) + + +def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]: + """Load kernel configurations from manifest file""" + + with open(manifest_file, "r") as f: + data = json.load(f) + + kernels = [] + for kernel_data in data.get("kernels", []): + kernel = KernelConfig( + name=kernel_data["name"], + header_file=kernel_data["header_file"], + tile_m=kernel_data["tile_m"], + tile_n=kernel_data["tile_n"], + tile_k=kernel_data["tile_k"], + warp_m=kernel_data.get("warp_m", 2), + warp_n=kernel_data.get("warp_n", 2), + warp_k=kernel_data.get("warp_k", 1), + warp_tile_m=kernel_data.get("warp_tile_m", 32), + warp_tile_n=kernel_data.get("warp_tile_n", 32), + warp_tile_k=kernel_data.get("warp_tile_k", 16), + block_size=kernel_data.get("block_size", 256), + pipeline=kernel_data.get("pipeline", "compv4"), + epilogue=kernel_data.get("epilogue", "cshuffle"), + scheduler=kernel_data.get("scheduler", "intrawave"), + pad_m=kernel_data.get("pad_m", False), + pad_n=kernel_data.get("pad_n", False), + pad_k=kernel_data.get("pad_k", False), + persistent=kernel_data.get("persistent", False), + double_buffer=kernel_data.get("double_buffer", True), + transpose_c=kernel_data.get("transpose_c", False), + dtype_a=kernel_data.get("dtype_a", "fp16"), + dtype_b=kernel_data.get("dtype_b", "fp16"), + dtype_c=kernel_data.get("dtype_c", "fp16"), + dtype_acc=kernel_data.get("dtype_acc", "fp32"), + ) + kernels.append(kernel) + + return kernels + + +def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]: + """Scan generated headers and extract kernel configurations""" + + import re + + kernels = [] + + for header_file in generated_dir.glob("**/*.hpp"): + try: + content = header_file.read_text() + + # Extract kernel name + name_match = re.search( + r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content + ) + if not name_match: + continue + + kernel_name = name_match.group(1) + + # Extract tile configuration (support ck_tile::index_t) + tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileM\s*=\s*(\d+)", + content, + ) + tile_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileN\s*=\s*(\d+)", + content, + ) + tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileK\s*=\s*(\d+)", + content, + ) + + tile_m = int(tile_m_match.group(1)) if tile_m_match else 256 + tile_n = int(tile_n_match.group(1)) if tile_n_match else 256 + tile_k = int(tile_k_match.group(1)) if tile_k_match else 32 + + # Extract warp configuration + warp_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_M\s*=\s*(\d+)", + content, + ) + warp_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_N\s*=\s*(\d+)", + content, + ) + warp_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_K\s*=\s*(\d+)", + content, + ) + + warp_m = int(warp_m_match.group(1)) if warp_m_match else 2 + warp_n = int(warp_n_match.group(1)) if warp_n_match else 2 + warp_k = int(warp_k_match.group(1)) if warp_k_match else 1 + + # Extract warp tile configuration + warp_tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileM\s*=\s*(\d+)", + content, + ) + warp_tile_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileN\s*=\s*(\d+)", + content, + ) + warp_tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileK\s*=\s*(\d+)", + content, + ) + + warp_tile_m = int(warp_tile_m_match.group(1)) if warp_tile_m_match else 32 + warp_tile_n = int(warp_tile_n_match.group(1)) if warp_tile_n_match else 32 + warp_tile_k = int(warp_tile_k_match.group(1)) if warp_tile_k_match else 16 + + # Extract other parameters (with defaults) + block_size_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+BlockSize\s*=\s*(\d+)", + content, + ) + block_size = int(block_size_match.group(1)) if block_size_match else 256 + + # Extract boolean flags + pad_m = re.search(r"kPadM\s*=\s*true", content) is not None + pad_n = re.search(r"kPadN\s*=\s*true", content) is not None + pad_k = re.search(r"kPadK\s*=\s*true", content) is not None + persistent = ( + re.search(r"UsePersistentKernel\s*=\s*true", content) is not None + ) + double_buffer = ( + re.search(r"DoubleSmemBuffer\s*=\s*true", content) is not None + ) + transpose_c = re.search(r"TransposeC\s*=\s*true", content) is not None + + kernel = KernelConfig( + name=kernel_name, + header_file=str(header_file.relative_to(generated_dir.parent)), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + block_size=block_size, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=pad_m, + pad_n=pad_n, + pad_k=pad_k, + persistent=persistent, + double_buffer=double_buffer, + transpose_c=transpose_c, + ) + + kernels.append(kernel) + + except Exception as e: + print(f"Warning: Failed to parse {header_file}: {e}") + continue + + return kernels + + +def main(): + parser = argparse.ArgumentParser( + description="Generate dispatcher registration code" + ) + parser.add_argument( + "--generated-dir", + type=str, + required=True, + help="Directory containing generated kernel headers", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for registration code", + ) + parser.add_argument( + "--manifest", type=str, help="Optional manifest file with kernel configurations" + ) + parser.add_argument( + "--scan", + action="store_true", + help="Scan generated headers instead of using manifest", + ) + + args = parser.parse_args() + + generated_dir = Path(args.generated_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load kernel configurations + if args.manifest: + print(f"Loading kernels from manifest: {args.manifest}") + kernels = load_kernel_manifest(Path(args.manifest)) + elif args.scan: + print(f"Scanning generated headers in: {generated_dir}") + kernels = scan_generated_headers(generated_dir) + else: + print("Error: Must specify either --manifest or --scan") + return 1 + + print(f"Found {len(kernels)} kernels") + + # Generate registration code + registration_header = output_dir / "dispatcher_registration.hpp" + registration_cpp = output_dir / "dispatcher_registration.cpp" + + generate_registration_header(kernels, registration_header) + generate_registration_cpp(kernels, registration_cpp) + + # Generate manifest for Python + manifest_output = output_dir / "kernels_manifest.json" + manifest_data = { + "kernels": [ + { + "name": k.name, + "header_file": k.header_file, + "tile_m": k.tile_m, + "tile_n": k.tile_n, + "tile_k": k.tile_k, + "block_size": k.block_size, + "persistent": k.persistent, + } + for k in kernels + ] + } + + with open(manifest_output, "w") as f: + json.dump(manifest_data, f, indent=2) + + print(f"✓ Generated manifest: {manifest_output}") + print("\n✓ Registration code generation complete!") + print(f" Total kernels: {len(kernels)}") + print(" Output files:") + print(f" - {registration_header}") + print(f" - {registration_cpp}") + print(f" - {manifest_output}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py new file mode 100644 index 0000000000..53a9bff3ed --- /dev/null +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate one .cpp wrapper file per kernel header for maximum parallel compilation. + +Each kernel becomes its own translation unit, enabling: + - Maximum parallelism with make -j$(nproc) + - Per-kernel build progress (e.g., [5/128] Building kernel: gemm_fp16_128x128) + - Incremental rebuilds (only changed kernels recompile) + - Fine-grained build time analysis + +Usage: + python3 generate_kernel_wrappers.py --kernel-dir build/generated_kernels --output-dir build/kernel_wrappers + +Output structure: + build/kernel_wrappers/ + ├── gemm_fp16_rcr_128x128x32.cpp + ├── gemm_fp16_rcr_256x256x64.cpp + ├── conv_fwd_fp16_2d_128x128.cpp + └── ... + +Each .cpp simply includes its corresponding .hpp and forces symbol emission. +""" + +import argparse +import sys +from pathlib import Path +from typing import List, Tuple +import concurrent.futures + + +WRAPPER_TEMPLATE_GEMM = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +// Force symbol emission for kernel registration +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Marker to prevent dead code elimination +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + +WRAPPER_TEMPLATE_CONV = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + +def generate_wrapper( + kernel_hpp: Path, output_dir: Path, index: int, total: int +) -> Tuple[Path, bool]: + """Generate a .cpp wrapper for a single kernel header.""" + kernel_name = kernel_hpp.stem + kernel_id = kernel_name.replace("-", "_").replace(".", "_") + + # Select template based on kernel type + if kernel_name.startswith("gemm"): + template = WRAPPER_TEMPLATE_GEMM + else: + template = WRAPPER_TEMPLATE_CONV + + content = template.format( + kernel_name=kernel_name, + kernel_hpp=kernel_hpp.name, + kernel_id=kernel_id, + ) + + output_cpp = output_dir / f"{kernel_name}.cpp" + + # Only write if content changed (for incremental builds) + if output_cpp.exists(): + existing = output_cpp.read_text() + if existing == content: + return output_cpp, False # No change + + output_cpp.write_text(content) + return output_cpp, True # Written + + +def generate_cmake_list( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate CMakeLists.txt that compiles each wrapper as a separate object.""" + + num_kernels = len(wrappers) + + cmake_content = f'''# SPDX-License-Identifier: MIT +# Auto-generated CMakeLists.txt for per-kernel parallel compilation +# Generated {num_kernels} kernel translation units + +cmake_minimum_required(VERSION 3.16) + +# ============================================================================= +# Per-Kernel Object Targets ({num_kernels} kernels) +# ============================================================================= +# Each kernel is compiled as a separate OBJECT library for maximum parallelism. +# Build with: make -j$(nproc) all_kernels +# +# Progress output: +# [ 1/{num_kernels}] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/{num_kernels}] Building kernel: gemm_fp16_rcr_256x256x64 +# ... + +set(KERNEL_INCLUDE_DIR "{kernel_dir}") +set(ALL_KERNEL_OBJECTS "") + +''' + + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_target = f"kobj_{kernel_name}" + + cmake_content += f""" +# [{idx}/{num_kernels}] {kernel_name} +add_library({obj_target} OBJECT {wrapper.name}) +target_include_directories({obj_target} PRIVATE ${{KERNEL_INCLUDE_DIR}} ${{CK_INCLUDE_DIR}}) +target_compile_options({obj_target} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +set_target_properties({obj_target} PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(hip_FOUND) + target_link_libraries({obj_target} PRIVATE hip::device hip::host) +endif() +list(APPEND ALL_KERNEL_OBJECTS $) +""" + + cmake_content += f""" + +# ============================================================================= +# Combined Kernel Library +# ============================================================================= +# Links all {num_kernels} kernel objects into a single shared library + +add_library(all_kernels SHARED ${{ALL_KERNEL_OBJECTS}}) +if(hip_FOUND) + target_link_libraries(all_kernels PRIVATE hip::device hip::host) +endif() +set_target_properties(all_kernels PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME "dispatcher_kernels" +) + +message(STATUS "Configured {num_kernels} kernel objects for parallel compilation") +message(STATUS "Build with: make -j$(nproc) all_kernels") +""" + + cmake_file = output_dir / "CMakeLists.txt" + cmake_file.write_text(cmake_content) + return cmake_file + + +def generate_ninja_build( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate build.ninja for even faster parallel compilation.""" + + num_kernels = len(wrappers) + + ninja_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated build.ninja for per-kernel parallel compilation +# {num_kernels} kernel translation units + +# Variables +cxx = hipcc +cxxflags = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal --offload-compress +includes = -I{kernel_dir} -I/opt/rocm/include + +# Rules +rule compile + command = $cxx $cxxflags $includes -c $in -o $out + description = [{num_kernels}] Building kernel: $kernel_name + +rule link + command = $cxx -shared $in -o $out -L/opt/rocm/lib -lamdhip64 + description = Linking: $out + +# Kernel objects +""" + + obj_files = [] + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_file = f"{kernel_name}.o" + obj_files.append(obj_file) + + ninja_content += f""" +build {obj_file}: compile {wrapper.name} + kernel_name = {kernel_name} +""" + + ninja_content += f""" + +# Shared library +build libdispatcher_kernels.so: link {" ".join(obj_files)} + +# Default target +default libdispatcher_kernels.so +""" + + ninja_file = output_dir / "build.ninja" + ninja_file.write_text(ninja_content) + return ninja_file + + +def generate_makefile(wrappers: List[Path], output_dir: Path, kernel_dir: Path) -> Path: + """Generate Makefile for per-kernel parallel compilation.""" + + num_kernels = len(wrappers) + kernel_names = [w.stem for w in wrappers] + obj_files = [f"{name}.o" for name in kernel_names] + + makefile_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated Makefile for per-kernel parallel compilation +# {num_kernels} kernel translation units +# +# Usage: +# make -j$(nproc) # Build all kernels in parallel +# make -j$(nproc) VERBOSE=1 # With per-kernel progress +# make clean # Remove all objects + +CXX = hipcc +CXXFLAGS = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 \\ + -Wno-undefined-func-template -Wno-float-equal --offload-compress +INCLUDES = -I{kernel_dir} -I/opt/rocm/include +LDFLAGS = -shared -L/opt/rocm/lib -lamdhip64 + +TARGET = libdispatcher_kernels.so +OBJECTS = {" ".join(obj_files)} + +# Progress counter (only works with make -j1, use ninja for parallel progress) +TOTAL_KERNELS = {num_kernels} +CURRENT = 0 + +.PHONY: all clean + +all: $(TARGET) +\t@echo "Built $(TARGET) with {num_kernels} kernels" + +$(TARGET): $(OBJECTS) +\t@echo "[LINK] Linking {num_kernels} kernel objects -> $@" +\t$(CXX) $(LDFLAGS) $^ -o $@ + +""" + + for idx, (wrapper, obj) in enumerate(zip(wrappers, obj_files), 1): + kernel_name = wrapper.stem + makefile_content += f""" +{obj}: {wrapper.name} +\t@echo "[{idx}/{num_kernels}] Building kernel: {kernel_name}" +\t$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ +""" + + makefile_content += f""" + +clean: +\trm -f $(OBJECTS) $(TARGET) +\t@echo "Cleaned {num_kernels} kernel objects" +""" + + makefile = output_dir / "Makefile" + makefile.write_text(makefile_content) + return makefile + + +def main(): + parser = argparse.ArgumentParser( + description="Generate per-kernel wrapper .cpp files for parallel compilation" + ) + parser.add_argument( + "--kernel-dir", + type=Path, + required=True, + help="Directory containing generated kernel .hpp files", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for wrapper .cpp files", + ) + parser.add_argument( + "--pattern", + type=str, + default="*.hpp", + help="Glob pattern for kernel headers (default: *.hpp)", + ) + parser.add_argument( + "--generate-cmake", + action="store_true", + help="Generate CMakeLists.txt for the wrappers", + ) + parser.add_argument( + "--generate-ninja", + action="store_true", + help="Generate build.ninja for ninja builds", + ) + parser.add_argument( + "--generate-makefile", + action="store_true", + help="Generate Makefile for make builds", + ) + parser.add_argument( + "--parallel", + action="store_true", + default=True, + help="Generate wrappers in parallel (default: True)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Verbose output", + ) + + args = parser.parse_args() + + # Find kernel headers + kernel_dir = args.kernel_dir.resolve() + if not kernel_dir.exists(): + print(f"Error: Kernel directory not found: {kernel_dir}", file=sys.stderr) + return 1 + + kernel_headers = sorted(kernel_dir.glob(args.pattern)) + if not kernel_headers: + print( + f"Error: No kernel headers found matching {args.pattern} in {kernel_dir}", + file=sys.stderr, + ) + return 1 + + num_kernels = len(kernel_headers) + print(f"Found {num_kernels} kernel headers in {kernel_dir}") + + # Create output directory + output_dir = args.output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate wrappers + print(f"Generating {num_kernels} wrapper .cpp files...") + + wrappers = [] + written = 0 + + if args.parallel and num_kernels > 1: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = { + executor.submit( + generate_wrapper, hpp, output_dir, idx, num_kernels + ): hpp + for idx, hpp in enumerate(kernel_headers, 1) + } + for future in concurrent.futures.as_completed(futures): + wrapper_path, was_written = future.result() + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" Generated: {wrapper_path.name}") + else: + for idx, hpp in enumerate(kernel_headers, 1): + wrapper_path, was_written = generate_wrapper( + hpp, output_dir, idx, num_kernels + ) + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" [{idx}/{num_kernels}] Generated: {wrapper_path.name}") + + wrappers.sort(key=lambda p: p.name) + + print( + f" Total: {num_kernels} wrappers ({written} written, {num_kernels - written} unchanged)" + ) + + # Generate build files + if args.generate_cmake: + cmake_file = generate_cmake_list(wrappers, output_dir, kernel_dir) + print(f" Generated: {cmake_file}") + + if args.generate_ninja: + ninja_file = generate_ninja_build(wrappers, output_dir, kernel_dir) + print(f" Generated: {ninja_file}") + + if args.generate_makefile: + makefile = generate_makefile(wrappers, output_dir, kernel_dir) + print(f" Generated: {makefile}") + + print(f"\nOutput directory: {output_dir}") + print(f"Kernels ready for parallel compilation: {num_kernels}") + print("\nTo build:") + print(f" cd {output_dir}") + if args.generate_makefile: + print(" make -j$(nproc) # Parallel build with progress") + if args.generate_ninja: + print(" ninja # Fast parallel build") + if args.generate_cmake: + print(" cmake -B build && cmake --build build -j$(nproc)") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py new file mode 100644 index 0000000000..537fc40581 --- /dev/null +++ b/dispatcher/codegen/kernel_config_loader.py @@ -0,0 +1,798 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Kernel Configuration Loader + +Load kernel configurations from JSON files for generating specific kernel sets. +Compatible with tile_engine JSON format. + +Usage: + from kernel_config_loader import load_kernel_configs, KernelConfigSet + + # Load configs from JSON + config_set = load_kernel_configs("my_kernels.json") + + # Get all configurations (cartesian product of all parameter values) + for config in config_set.generate_configs(): + print(config) + + # Use with codegen + from unified_gemm_codegen import UnifiedGemmCodegen + codegen = UnifiedGemmCodegen(...) + codegen.generate_from_configs(config_set.generate_configs()) +""" + +import json +import itertools +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Dict, Any, Optional, Iterator + + +@dataclass +class TileConfig: + """Tile configuration for a kernel""" + + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class TraitConfig: + """Trait configuration for a kernel (order matches GEMM/Conv TraitConfig)""" + + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + pad_m: bool = False + pad_n: bool = False + pad_k: bool = False + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig = field(default_factory=TileConfig) + trait: TraitConfig = field(default_factory=TraitConfig) + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_target: str = "gfx942" + variant: str = "standard" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + "warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "dtype_a": self.dtype_a, + "dtype_b": self.dtype_b, + "dtype_c": self.dtype_c, + "dtype_acc": self.dtype_acc, + "layout": self.layout, + "gpu_target": self.gpu_target, + "variant": self.variant, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + name = f"gemm_{self.dtype_a}_{self.layout}_{self.trait.pipeline}" + name += f"_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{str(self.trait.pad_m).capitalize()}" + name += f"_{str(self.trait.pad_n).capitalize()}" + name += f"_{str(self.trait.pad_k).capitalize()}" + name += "_False" # preshuffle + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class KernelConfigSet: + """A set of kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[KernelConfig] = field(default_factory=list) + + # Parameter ranges for generation + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + pipeline_values: List[str] = field(default_factory=lambda: ["compv4"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [False]) + pad_n_values: List[bool] = field(default_factory=lambda: [False]) + pad_k_values: List[bool] = field(default_factory=lambda: [False]) + + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + variant: str = "standard" + + def generate_configs(self) -> Iterator[KernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + tile_cfg = TileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = TraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + ) + yield KernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_a=self.dtype_a, + dtype_b=self.dtype_b, + dtype_c=self.dtype_c, + dtype_acc=self.dtype_acc, + layout=self.layout, + gpu_target=gpu_target, + variant=self.variant, + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + ) + return tile_count * trait_count * len(self.gpu_targets) + + +def _get_values(config: Dict, key: str, default: List) -> List: + """Extract values from config dict, handling range specifications""" + if key not in config: + return default + + item = config[key] + + # Explicit values list + if "values" in item: + return item["values"] + + # Range specification (min, max, step) + if "min" in item and "max" in item: + min_val = item["min"] + max_val = item["max"] + step = item.get("step", 1) + return list(range(min_val, max_val + 1, step)) + + return default + + +def load_kernel_configs(json_path: str | Path) -> KernelConfigSet: + """ + Load kernel configurations from a JSON file. + + Supports both tile_engine format and dispatcher format. + + Args: + json_path: Path to JSON configuration file + + Returns: + KernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = KernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_a = dt.get("a", "fp16") + config_set.dtype_b = dt.get("b", "fp16") + config_set.dtype_c = dt.get("c", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Layout + config_set.layout = data.get("layout", "rcr") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Variant + config_set.variant = data.get("variant", "standard") + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv4"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [False]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [False]) + config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [False]) + + return config_set + + +# ============================================================================= +# Convolution Configuration Classes +# ============================================================================= + + +@dataclass +class ConvTileConfig: + """Tile configuration for a convolution kernel""" + + tile_m: int = 128 # M dimension (N * spatial_out for fwd) + tile_n: int = 128 # N dimension (K output channels for fwd) + tile_k: int = 32 # K dimension (C * filter for fwd) + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class ConvTraitConfig: + """Trait configuration for a convolution kernel""" + + pipeline: str = "compv3" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + double_smem_buffer: bool = False + num_groups_to_merge: int = 1 + + +@dataclass +class ConvKernelConfig: + """Complete convolution kernel configuration""" + + tile: ConvTileConfig = field(default_factory=ConvTileConfig) + trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + variant: str = "forward" # forward, bwd_data, bwd_weight + ndim: int = 2 # 1, 2, or 3 + layout: str = "nhwgc" + gpu_target: str = "gfx942" + + # Vector sizes + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + + # Occupancy + block_per_cu: int = 1 + num_wave_groups: int = 1 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + "warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "double_smem_buffer": self.trait.double_smem_buffer, + "num_groups_to_merge": self.trait.num_groups_to_merge, + "dtype_input": self.dtype_input, + "dtype_weight": self.dtype_weight, + "dtype_output": self.dtype_output, + "dtype_acc": self.dtype_acc, + "variant": self.variant, + "ndim": self.ndim, + "layout": self.layout, + "gpu_target": self.gpu_target, + "vector_size_a": self.vector_size_a, + "vector_size_b": self.vector_size_b, + "vector_size_c": self.vector_size_c, + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + var_str = variant_map.get(self.variant, self.variant) + + name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" + name += f"_{self.trait.pipeline}_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class ConvKernelConfigSet: + """A set of convolution kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[ConvKernelConfig] = field(default_factory=list) + + # Tile parameter ranges + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + # Trait parameter ranges + pipeline_values: List[str] = field(default_factory=lambda: ["compv3"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [True]) + pad_n_values: List[bool] = field(default_factory=lambda: [True]) + pad_k_values: List[bool] = field(default_factory=lambda: [True]) + double_smem_buffer_values: List[bool] = field(default_factory=lambda: [False]) + num_groups_to_merge_values: List[int] = field(default_factory=lambda: [1]) + + # Vector sizes + vector_size_a_values: List[int] = field(default_factory=lambda: [4]) + vector_size_b_values: List[int] = field(default_factory=lambda: [8]) + vector_size_c_values: List[int] = field(default_factory=lambda: [8]) + + # Occupancy + block_per_cu_values: List[int] = field(default_factory=lambda: [1]) + num_wave_groups_values: List[int] = field(default_factory=lambda: [1]) + + # Data types + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + + # Conv specific + variant: str = "forward" + ndim: int = 2 + layout: str = "nhwgc" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + + def generate_configs(self) -> Iterator[ConvKernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + self.double_smem_buffer_values, + self.num_groups_to_merge_values, + ) + + # Vector/occupancy parameters + extra_params = itertools.product( + self.vector_size_a_values, + self.vector_size_b_values, + self.vector_size_c_values, + self.block_per_cu_values, + self.num_wave_groups_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + extra_list = list(extra_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + for extra in extra_list: + tile_cfg = ConvTileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = ConvTraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + double_smem_buffer=trait[6], + num_groups_to_merge=trait[7], + ) + yield ConvKernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_input=self.dtype_input, + dtype_weight=self.dtype_weight, + dtype_output=self.dtype_output, + dtype_acc=self.dtype_acc, + variant=self.variant, + ndim=self.ndim, + layout=self.layout, + gpu_target=gpu_target, + vector_size_a=extra[0], + vector_size_b=extra[1], + vector_size_c=extra[2], + block_per_cu=extra[3], + num_wave_groups=extra[4], + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + * len(self.double_smem_buffer_values) + * len(self.num_groups_to_merge_values) + ) + extra_count = ( + len(self.vector_size_a_values) + * len(self.vector_size_b_values) + * len(self.vector_size_c_values) + * len(self.block_per_cu_values) + * len(self.num_wave_groups_values) + ) + return tile_count * trait_count * extra_count * len(self.gpu_targets) + + +def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: + """ + Load convolution kernel configurations from a JSON file. + + Args: + json_path: Path to JSON configuration file + + Returns: + ConvKernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = ConvKernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_input = dt.get("input", "fp16") + config_set.dtype_weight = dt.get("weight", "fp16") + config_set.dtype_output = dt.get("output", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Conv specific + config_set.variant = data.get("variant", "forward") + config_set.ndim = data.get("ndim", 2) + config_set.layout = data.get("layout", "nhwgc") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv3"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [True]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [True]) + config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [True]) + config_set.double_smem_buffer_values = _get_values( + trait_cfg, "double_smem_buffer", [False] + ) + config_set.num_groups_to_merge_values = _get_values( + trait_cfg, "num_groups_to_merge", [1] + ) + + # Vector config + vec_cfg = data.get("vector_config", {}) + config_set.vector_size_a_values = _get_values(vec_cfg, "vector_size_a", [4]) + config_set.vector_size_b_values = _get_values(vec_cfg, "vector_size_b", [8]) + config_set.vector_size_c_values = _get_values(vec_cfg, "vector_size_c", [8]) + + # Occupancy config + occ_cfg = data.get("occupancy_config", {}) + config_set.block_per_cu_values = _get_values(occ_cfg, "block_per_cu", [1]) + config_set.num_wave_groups_values = _get_values(occ_cfg, "num_wave_groups", [1]) + + return config_set + + +def generate_cpp_conv_kernel_set_declaration( + config_set: ConvKernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. + """ + name = set_name or config_set.name + + lines = [f"DECL_CONV_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# ============================================================================= +# GEMM Configuration Export Functions +# ============================================================================= + + +def generate_cpp_kernel_set_declaration( + config_set: KernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_KERNEL_SET code from a KernelConfigSet. + + Args: + config_set: The kernel configuration set + set_name: Optional name override for the kernel set + + Returns: + C++ code string with DECL_KERNEL_SET declaration + """ + name = set_name or config_set.name + + lines = [f"DECL_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + # Generate .add() call for each config + line = f' .add("{config.dtype_a}", "{config.layout}", ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# CLI for testing +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python kernel_config_loader.py ") + print("\nLoads kernel configurations from JSON and prints summary.") + sys.exit(1) + + json_path = sys.argv[1] + + try: + config_set = load_kernel_configs(json_path) + + print(f"Kernel Set: {config_set.name}") + print( + f"Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}, Acc={config_set.dtype_acc}" + ) + print(f"Layout: {config_set.layout}") + print(f"GPU Targets: {config_set.gpu_targets}") + print(f"Variant: {config_set.variant}") + print() + print("Tile Configurations:") + print(f" tile_m: {config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print(f" warp_m: {config_set.warp_m_values}") + print(f" warp_n: {config_set.warp_n_values}") + print(f" warp_k: {config_set.warp_k_values}") + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + print() + print("Trait Configurations:") + print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + print() + print(f"Total configurations: {config_set.config_count()}") + print() + + # Print first few config names + print("Sample kernel names:") + for i, config in enumerate(config_set.generate_configs()): + if i >= 5: + print(f" ... and {config_set.config_count() - 5} more") + break + print(f" {config.kernel_name()}") + print() + + # Generate C++ code + if "--cpp" in sys.argv: + print("C++ Declaration:") + print("-" * 60) + print(generate_cpp_kernel_set_declaration(config_set)) + + except Exception as e: + print(f"Error: {e}") + sys.exit(1) diff --git a/dispatcher/codegen/preselected_kernels.py b/dispatcher/codegen/preselected_kernels.py new file mode 100644 index 0000000000..010d930639 --- /dev/null +++ b/dispatcher/codegen/preselected_kernels.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Preselected, Benchmarked Kernel Configurations + +Curated kernel sets optimized for different workload characteristics: +- Compute-friendly: Large tiles, high arithmetic intensity +- Memory-friendly: Smaller tiles, better memory access patterns +- Latency-friendly: Minimal tiles, low latency for small problems +""" + +from functools import partial, lru_cache +from typing import List +from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant + + +# ============================================================================ +# Base Configurations +# ============================================================================ + + +def _base_fp16_rcr_compute() -> partial: + """Base configuration for compute-intensive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_memory() -> partial: + """Base configuration for memory-intensive FP16 RCR kernels""" + # Note: Use 'mem' pipeline for interwave scheduler (compv3/compv4/compv5/compv6 only support intrawave) + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="cshuffle", + scheduler="interwave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_latency() -> partial: + """Base configuration for latency-sensitive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="default", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# Preselected FP16 RCR Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_compute() -> List[KernelConfig]: + """ + Compute-friendly FP16 RCR kernels + + Optimized for: + - Large M, N dimensions (>= 128) + - High arithmetic intensity + - Good occupancy + - Maximum throughput + """ + base = _base_fp16_rcr_compute() + + return [ + # Large tiles for maximum compute + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 128, 32, 4, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 256, 32, 2, 4, 1, 32, 32, 16)), + # Balanced tiles + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + # With persistent kernel for large batches + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=False, + pad_n=False, + pad_k=False, + persistent=True, + ), + ), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_memory() -> List[KernelConfig]: + """ + Memory-friendly FP16 RCR kernels + + Optimized for: + - Small to medium M, N dimensions + - Memory-bound workloads + - Better cache utilization + - Lower register pressure + """ + base = _base_fp16_rcr_memory() + + return [ + # Small tiles for memory efficiency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(16, 64, 32, 1, 2, 1, 16, 16, 16)), + base(tile=TileConfig(64, 16, 32, 2, 1, 1, 16, 16, 16)), + # Medium tiles + base(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + base(tile=TileConfig(32, 128, 32, 1, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 32, 32, 2, 1, 1, 32, 32, 16)), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_latency() -> List[KernelConfig]: + """ + Latency-friendly FP16 RCR kernels + + Optimized for: + - Very small M, N dimensions (< 64) + - Minimal launch overhead + - Low latency + - Quick execution + """ + base = _base_fp16_rcr_latency() + + return [ + # Minimal tiles for low latency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + ] + + +# ============================================================================ +# Preselected Multi-D Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_multi_d() -> List[KernelConfig]: + """ + Multi-D GEMM kernels with element-wise fusion + + Common fusions: + - MultiDAdd: E = C + D0 + D1 + - Relu: E = max(C, 0) + - Gelu: E = gelu(C) + """ + base = _base_fp16_rcr_compute() + + configs = [] + + # Best-performing tile for fused operations + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + # Common element-wise operations + for ew_op in ["MultiDAdd", "Relu", "Gelu", "FastGelu"]: + for num_d in [1, 2]: + configs.append( + base( + tile=tile, + variant=GemmVariant.MULTI_D, + elementwise_op=ew_op, + num_d_tensors=num_d, + ) + ) + + return configs + + +@lru_cache(None) +def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]: + """ + Preshuffle GEMM kernels for weight optimization + + Best for: + - Repeated use of same weights + - Inference workloads + - Batch size > 1 + """ + base = _base_fp16_rcr_compute() + + return [ + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + base( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + ] + + +# ============================================================================ +# Unified Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_all() -> List[KernelConfig]: + """All preselected FP16 RCR kernels""" + return ( + preselected_fp16_rcr_compute() + + preselected_fp16_rcr_memory() + + preselected_fp16_rcr_latency() + + preselected_fp16_rcr_multi_d() + + preselected_fp16_rcr_preshuffle() + ) + + +@lru_cache(None) +def preselected_fp16_rcr_essential() -> List[KernelConfig]: + """ + Essential FP16 RCR kernels - minimal set for most workloads + + Covers: + - 90% of common GEMM sizes + - Key fusion operations + - Balanced performance + """ + base_compute = _base_fp16_rcr_compute() + base_memory = _base_fp16_rcr_memory() + + return [ + # Top compute kernels + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + # Top memory kernels + base_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + # Essential fusions + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Relu", + num_d_tensors=1, + ), + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Gelu", + num_d_tensors=1, + ), + ] + + +# ============================================================================ +# Default Fallback +# ============================================================================ + + +def default_kernel() -> KernelConfig: + """ + Default fallback kernel - guaranteed to work + + Known-good configuration tested on gfx942 + """ + return KernelConfig( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# BF16 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_bf16_rcr_essential() -> List[KernelConfig]: + """Essential BF16 RCR kernels""" + base_compute = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# INT8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_int8_rcr_essential() -> List[KernelConfig]: + """Essential INT8 RCR kernels for quantized inference""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# FP8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp8_rcr_essential() -> List[KernelConfig]: + """Essential FP8 RCR kernels for AI training""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Mixed Precision Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_mixed_precision() -> List[KernelConfig]: + """Mixed-precision kernels (FP16 inputs, FP32 output)""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Registry +# ============================================================================ + +PRESELECTED_SETS = { + # FP16 sets + "fp16_rcr_compute": preselected_fp16_rcr_compute, + "fp16_rcr_memory": preselected_fp16_rcr_memory, + "fp16_rcr_latency": preselected_fp16_rcr_latency, + "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d, + "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle, + "fp16_rcr_all": preselected_fp16_rcr_all, + "fp16_rcr_essential": preselected_fp16_rcr_essential, + # BF16 sets + "bf16_rcr_essential": preselected_bf16_rcr_essential, + # INT8 sets + "int8_rcr_essential": preselected_int8_rcr_essential, + # FP8 sets + "fp8_rcr_essential": preselected_fp8_rcr_essential, + # Mixed precision + "mixed_precision": preselected_mixed_precision, +} + + +def get_preselected_set(name: str) -> List[KernelConfig]: + """Get a preselected kernel set by name""" + if name not in PRESELECTED_SETS: + raise ValueError( + f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}" + ) + return PRESELECTED_SETS[name]() + + +def list_preselected_sets() -> List[str]: + """List all available preselected sets""" + return list(PRESELECTED_SETS.keys()) + + +# ============================================================================ +# CLI for testing +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="List preselected kernel configurations" + ) + parser.add_argument( + "--set", + type=str, + default="fp16_rcr_essential", + choices=list_preselected_sets(), + help="Preselected set to display", + ) + parser.add_argument("--count-only", action="store_true", help="Only show count") + + args = parser.parse_args() + + configs = get_preselected_set(args.set) + + if args.count_only: + print(f"{args.set}: {len(configs)} kernels") + else: + print(f"Preselected set: {args.set}") + print(f"Total kernels: {len(configs)}\n") + for i, cfg in enumerate(configs, 1): + print(f"{i}. {cfg.variant.value}") + print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}") + print(f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}") + if cfg.variant == GemmVariant.MULTI_D: + print( + f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}" + ) + print() diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py new file mode 100755 index 0000000000..b0dd961be7 --- /dev/null +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -0,0 +1,1713 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified GEMM Code Generator - Single Source of Truth + +This is THE unified code generator for all GEMM kernel variants: +- Standard GEMM (C = A × B) +- Preshuffle GEMM (optimized weight access) +- Multi-D GEMM (element-wise fusion) + +Generates both CK Tile kernels AND dispatcher wrappers in one pass. +Replaces all tile_engine GEMM codegen. +""" + +import json +import argparse +import itertools +import logging +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass, asdict +from enum import Enum +import concurrent.futures + +# Import architecture filter for GPU-specific validation +try: + from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType + + HAS_ARCH_FILTER = True +except ImportError: + HAS_ARCH_FILTER = False + ArchFilter = None + ArchKernelConfig = None + OperatorType = None + + +# ============================================================================= +# Preshuffle Validation (copied from tile_engine/ops/commons/gemm_validation_utils.py) +# ============================================================================= + +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, +} + + +def _validate_preshuffle_vector_load( + warp_tile_m: int, + warp_tile_k: int, + datatype: str, + m_iter_per_warp: float, + wave_size: int = 64, + vector_load_size: int = 16, +) -> bool: + """ + Validate vector load alignment for preshuffle pipeline. + + Checks: (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp / wave_size) % vector_load_size == 0 + """ + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + access_size = (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp) / wave_size + return access_size % vector_load_size == 0 + + +def _validate_preshuffle_m0_m1_m2( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> bool: + """ + Validate M0, M1, M2 configuration for preshuffle matrix A row-major layout. + Ensures proper memory access pattern alignment. + """ + try: + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + MPerBlock = tile_m + + # Calculate K1 + K1 = vector_load_size / elem_size + if K1 != int(K1): + return False + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False + M2 = warp_size // K0 + + # Calculate number of warps + NumWarps = warp_m * warp_n * warp_k + M0 = NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False + if MPerBlock % (M2 * M0) != 0: + return False + M1 = MPerBlock // (M2 * M0) + + # Validate: M0 * M1 * M2 == MPerBlock + return (M0 * M1 * M2) == MPerBlock + + except (ZeroDivisionError, ValueError): + return False + + +def is_preshuffle_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + datatype: str, +) -> bool: + """ + Comprehensive preshuffle configuration validation. + Copied from tile_engine/ops/commons/gemm_validation_utils.py + """ + # Basic divisibility checks + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + # Calculate m_iter_per_warp + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + + # Validate vector load alignment + if not _validate_preshuffle_vector_load( + warp_tile_m, + warp_tile_k, + datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ): + return False + + # Validate M0/M1/M2 configuration + if not _validate_preshuffle_m0_m1_m2( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + datatype, + vector_load_size=16, + warp_size=64, + ): + return False + + return True + + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Configuration and Data Structures +# ============================================================================ + + +class GemmVariant(Enum): + """GEMM kernel variants""" + + STANDARD = "standard" + PRESHUFFLE = "preshuffle" + MULTI_D = "multi_d" + + +@dataclass +class TileConfig: + """Tile configuration parameters""" + + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + def is_valid(self) -> bool: + """Validate tile configuration""" + return ( + self.tile_m % (self.warp_m * self.warp_tile_m) == 0 + and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 + and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 + and self.tile_m > 0 + and self.tile_n > 0 + and self.tile_k > 0 + ) + + +@dataclass +class TraitConfig: + """Kernel trait configuration""" + + pipeline: str # mem, compv3, compv4 + epilogue: str # default, cshuffle + scheduler: str # intrawave, interwave + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + + def is_valid(self) -> bool: + """Check if trait combination is valid""" + # Unsupported combinations + # Only 'mem' pipeline supports interwave scheduler. + # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. + unsupported = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), + } + return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig + trait: TraitConfig + variant: GemmVariant = GemmVariant.STANDARD + + # Variant-specific + preshuffle: bool = False + elementwise_op: str = "PassThrough" + num_d_tensors: int = 0 + d_layout: str = "r" # Layout for D tensors (r=row, c=col) - same for all D tensors + + # Fixed parameters + block_size: int = 256 + k_block_per_cu: int = 1 + num_wave_groups: int = 1 + + def name(self, datatype: str, layout: str) -> str: + """C++ alias for template instance""" + return f"ck_tile_gemm_{self.key_name(datatype, layout)}" + + def key_name(self, datatype: str, layout: str) -> str: + """ + Unique identifier for this kernel configuration. + + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Data type and layout (signature) + - Tile, warp, warp_tile dimensions (algorithm) + - Pipeline, epilogue, scheduler (traits) + - Padding flags (affects divisibility requirements) + - Persistent mode + - Preshuffle variant + - Multi-D: elementwise op, num D tensors, D layout + - Occupancy: wave groups, k_block_per_cu (if non-default) + """ + parts = [] + # Signature + parts.append(f"dt_{datatype}") + parts.append(f"ly_{layout}") + + # Tile configuration + parts.append(f"tile_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}") + parts.append(f"warp_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}") + parts.append( + f"wtile_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + + # Traits + parts.append(f"pipe_{self.trait.pipeline}") + parts.append(f"epi_{self.trait.epilogue}") + parts.append(f"sched_{self.trait.scheduler}") + + # Padding flags (only if not all True - the common case) + if not (self.trait.pad_m and self.trait.pad_n and self.trait.pad_k): + parts.append( + f"pad{int(self.trait.pad_m)}{int(self.trait.pad_n)}{int(self.trait.pad_k)}" + ) + + # Persistent mode + if self.trait.persistent: + parts.append("persist") + + # Preshuffle variant + if self.preshuffle: + parts.append("preshuffle") + + # Multi-D variant: include elementwise op, num tensors, and D layout + if self.variant == GemmVariant.MULTI_D: + parts.append(f"ew_{self.elementwise_op}") + parts.append(f"nd{self.num_d_tensors}") + parts.append(f"dly_{self.d_layout}") + + # Occupancy parameters (only if non-default) + if self.num_wave_groups != 1: + parts.append(f"wg{self.num_wave_groups}") + if self.k_block_per_cu != 1: + parts.append(f"kbpc{self.k_block_per_cu}") + + return "_".join(parts) + + def dict_items(self): + """Iterator over (field, value) pairs""" + return asdict(self).items() + + +# ============================================================================ +# Type Mappings +# ============================================================================ + + +class TypeMappings: + """Centralized type mappings for code generation""" + + DTYPE_TO_CK = { + "fp16": "fp16_t", + "bf16": "bf16_t", + "fp32": "float", + "fp8": "fp8_t", + "bf8": "bf8_t", + "int8": "int8_t", + } + + # Fully-qualified types for use outside of 'using namespace ck_tile' scope + DTYPE_TO_CK_QUALIFIED = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", # Built-in type, no namespace + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "int8_t", # Built-in type + } + + DTYPE_TO_DISPATCHER = { + "fp16": "DataType::FP16", + "bf16": "DataType::BF16", + "fp32": "DataType::FP32", + "fp8": "DataType::FP8", + "bf8": "DataType::BF8", + "int8": "DataType::INT8", + } + + LAYOUT_TO_CK = { + "r": "tensor_layout::gemm::RowMajor", + "c": "tensor_layout::gemm::ColumnMajor", + } + + LAYOUT_TO_DISPATCHER = { + "r": "LayoutTag::RowMajor", + "c": "LayoutTag::ColMajor", + } + + PIPELINE_TO_CK = { + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_BASE = { + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + "default": "GemmPipelineScheduler::Default", + } + + SCHEDULER_TO_DISPATCHER = { + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + "default": "Scheduler::Auto", + } + + EPILOGUE_TO_DISPATCHER = { + "cshuffle": "Epilogue::CShuffle", + "default": "Epilogue::Default", + } + + @staticmethod + def get_output_dtype(dtype: str) -> str: + """Get output datatype (fp8/bf8 -> fp16)""" + return "fp16" if dtype in ["fp8", "bf8"] else dtype + + +# ============================================================================ +# Kernel Name Generator +# ============================================================================ + + +class KernelNaming: + """Unified kernel naming""" + + @staticmethod + def generate(config: KernelConfig, datatype: str, layout: str) -> str: + """Generate kernel name following tile_engine convention""" + t = config.tile + tr = config.trait + + # For multi-d, use 4-char layout (abcd), otherwise use 3-char layout (abc) + if config.variant == GemmVariant.MULTI_D: + full_layout = layout + config.d_layout # e.g., "rcr" + "r" = "rcrr" + else: + full_layout = layout + + name = ( + f"gemm_{datatype}_{full_layout}_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + ) + name += f"_{str(tr.pad_m).capitalize()}_{str(tr.pad_n).capitalize()}" + name += f"_{str(tr.pad_k).capitalize()}_{str(tr.persistent).capitalize()}" + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Add variant suffix + if config.variant == GemmVariant.PRESHUFFLE: + name += "_preshuffle" + elif config.variant == GemmVariant.MULTI_D: + name += f"_multid_{config.elementwise_op}_d{config.num_d_tensors}" + + return name + + +# ============================================================================ +# CK Tile Kernel Generator +# ============================================================================ + + +class CKTileKernelGenerator: + """Generates CK Tile kernel instance code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate(self, config: KernelConfig) -> str: + """Generate complete CK Tile kernel""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + return f"""{self._header(kernel_name, config)} +{self._types(config, kernel_name)} +{self._selected_kernel_struct(config, kernel_name)} +""" + + def _header(self, kernel_name: str, config: KernelConfig) -> str: + """Generate header includes""" + includes = """// SPDX-License-Identifier: MIT +// Auto-generated CK Tile GEMM kernel +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +""" + + if config.variant == GemmVariant.MULTI_D: + includes += """ +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +""" + + if config.preshuffle: + includes += """ +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +""" + + return includes + + def _types(self, config: KernelConfig, kernel_name: str) -> str: + """Generate type definitions - just the namespace import, types are in kernel namespace""" + # Note: Data types and layouts are now defined inside each kernel's unique namespace + # to avoid type alias redefinition conflicts when mixing layouts (e.g., RCR + RRR) + types = """ +// Use ck_tile namespace for generated code +using namespace ck_tile; +""" + return types + + def _kernel_local_types(self, config: KernelConfig) -> str: + """Generate data type and layout definitions inside kernel namespace""" + output_dtype = self.tm.get_output_dtype(self.datatype) + + return f""" + // Data types (inside namespace to avoid conflicts across layouts) + using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + + // Layouts (inside namespace to avoid conflicts when mixing layouts) + using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; + using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; + using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +""" + + def _multi_d_types(self, config: KernelConfig) -> str: + """Generate multi-d type definitions (inside namespace to avoid conflicts)""" + if config.variant != GemmVariant.MULTI_D: + return "" + + d_types = ", ".join(["CDataType"] * config.num_d_tensors) + d_layout_ck = self.tm.LAYOUT_TO_CK[config.d_layout] + d_layouts = ", ".join([d_layout_ck] * config.num_d_tensors) + + return f""" +// Multi-D types (defined in namespace to avoid conflicts) +using DsDataType = tuple<{d_types}>; +using DLayout = {d_layout_ck}; // D tensor layout (can differ from C) +using DsLayout = tuple<{d_layouts}>; +using ElementWiseFn = element_wise::{config.elementwise_op}; +static constexpr index_t NumDTensor = {config.num_d_tensors}; +using GemmMultiDArgs = GemmMultiDHostArgs; +""" + + def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str: + """Generate SelectedKernel struct with unique name in unique namespace""" + t = config.tile + tr = config.trait + output_dtype = self.tm.get_output_dtype(self.datatype) + + # Generate unique struct name and namespace from kernel name + struct_name = f"Kernel_{kernel_name}" + # Create valid C++ namespace name (replace invalid chars) + ns_name = "ns_" + kernel_name.replace("-", "_") + + multi_d_types = self._multi_d_types(config) + + return f""" +namespace {ns_name} {{ +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Data types (inside namespace to avoid conflicts across different kernels) +using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; +using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; +using AccDataType = float; +using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + +// Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR) +using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; +using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; +using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +{multi_d_types} +struct {struct_name} {{ + // Data types (required by backend as member types) + using ADataType = {ns_name}::ADataType; + using BDataType = {ns_name}::BDataType; + using CDataType = {ns_name}::CDataType; + using AccDataType = {ns_name}::AccDataType; + + // Configuration + static constexpr index_t BlockSize = {config.block_size}; + static constexpr index_t TileM = {t.tile_m}; + static constexpr index_t TileN = {t.tile_n}; + static constexpr index_t TileK = {t.tile_k}; + static constexpr index_t WarpPerBlock_M = {t.warp_m}; + static constexpr index_t WarpPerBlock_N = {t.warp_n}; + static constexpr index_t WarpPerBlock_K = {t.warp_k}; + static constexpr index_t WarpTileM = {t.warp_tile_m}; + static constexpr index_t WarpTileN = {t.warp_tile_n}; + static constexpr index_t WarpTileK = {t.warp_tile_k}; + + // Traits + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {str(tr.persistent).lower()}; + static constexpr bool DoubleSmemBuffer = {str(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2").lower()}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = {str(config.preshuffle).lower()}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + + {self._tile_types(config, ns_name)} + {self._launch_function(config)} +}}; + +// Alias for tile_engine style compatibility (when used with -include) +using SelectedKernel = {struct_name}; +using SelectedKernelLauncher = {struct_name}; +}} // namespace {ns_name} + +// Export to global namespace ONLY for single-kernel includes +// Define CK_TILE_SINGLE_KERNEL_INCLUDE before including this header to enable these aliases +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {struct_name} = {ns_name}::{struct_name}; +using SelectedKernel = {ns_name}::{struct_name}; +constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME; +using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; +using AccDataType = float; +#endif // CK_TILE_SINGLE_KERNEL_INCLUDE +""" + + def _tile_types(self, config: KernelConfig, ns_name: str) -> str: + """Generate tile type definitions - uses namespace-qualified types""" + return ( + f"""// Tile shape + using TileShape = TileGemmShape< + sequence, + sequence, + sequence, + false, false>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner; + using Traits = TileGemmTraits; + using GemmPipelineProblem = GemmPipelineProblem; + using BaseGemmPipeline = """ + + self.tm.PIPELINE_TO_BASE[config.trait.pipeline] + + """;""" + ) + + def _launch_function(self, config: KernelConfig) -> str: + """Generate launch function""" + if config.variant == GemmVariant.MULTI_D: + return self._launch_function_multi_d(config) + if config.preshuffle: + return self._launch_function_preshuffle(config) + return self._launch_function_standard(config) + + def _launch_function_standard(self, config: KernelConfig) -> str: + """Generate launch function for standard GEMM""" + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_preshuffle(self, config: KernelConfig) -> str: + """Generate launch function for preshuffle GEMM (weight preshuffle variant) + + Preshuffle uses WeightPreshufflePipelineAGmemBGmemCRegV2 which has a different + API than standard pipelines. It's designed for weight-preshuffled GEMM operations. + """ + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = GemmPipelineScheduler::Default; // Preshuffle uses Default scheduler + + // Preshuffle uses TileFlatmmShape instead of TileGemmShape for the problem + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = WeightPreshufflePipelineAGmemBGmemCRegV2; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for preshuffle kernel!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_multi_d(self, config: KernelConfig) -> str: + """Generate launch function for Multi-D GEMM""" + return f""" + // Multi-D launch function - takes GemmMultiDHostArgs with D tensor pointers + static float launch(const GemmMultiDArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + // Use GemmKernelMultiD for Multi-D variant + using GemmKernel = ck_tile::GemmKernelMultiD; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported! Multi-D currently doesn't support k_batch > 1"); + }} + + const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} + + // Overload for standard GemmHostArgs (converts to Multi-D args with empty D tensors) + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + std::array empty_ds{{}}; + std::array empty_strides{{}}; + for (index_t i = 0; i < NumDTensor; ++i) {{ + empty_ds[i] = nullptr; + empty_strides[i] = 0; + }} + GemmMultiDArgs multi_d_args{{ + args.a_ptr, + args.b_ptr, + empty_ds, + args.e_ptr, + args.k_batch, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + empty_strides, + args.stride_C + }}; + return launch(multi_d_args, stream); + }}""" + + def _epilogue_code(self, config: KernelConfig) -> str: + """Generate epilogue code""" + if config.variant == GemmVariant.MULTI_D: + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, DsDataType, AccDataType, CDataType, + DsLayout, CLayout, ElementWiseFn, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + elif config.trait.epilogue == "cshuffle": + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + else: + return """ + using EpilogueProblem = DefaultGemm2DEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + kPadM, kPadN, WarpTileM, WarpTileN, WarpTileK, TransposeC>; + using GemmEpilogue = DefaultGemm2DEpilogue;""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class DispatcherWrapperGenerator: + """Generates dispatcher wrapper code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate( + self, config: KernelConfig, kernel_path: Path, output_dir: Path + ) -> str: + """Generate dispatcher wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + output_dtype = self.tm.get_output_dtype(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" +#include "{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::KernelInstancePtr; +using ::ck_tile::dispatcher::KernelKey; +using ::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::Registry::Priority; +namespace backends = ::ck_tile::dispatcher::backends; + +inline KernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + // Use the unique kernel struct name + using KernelStruct = Kernel_{kernel_name}; + + KernelKey key; + + // Signature + key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]}; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]}; + key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]}; + key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]}; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "{config.elementwise_op}"; + key.signature.num_d_tensors = {config.num_d_tensors}; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, {config.tile.warp_k}}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self.tm.PIPELINE_TO_DISPATCHER[config.trait.pipeline]}; + key.algorithm.scheduler = {self.tm.SCHEDULER_TO_DISPATCHER[config.trait.scheduler]}; + key.algorithm.epilogue = {self.tm.EPILOGUE_TO_DISPATCHER[config.trait.epilogue]}; + key.algorithm.block_size = {config.block_size}; + key.algorithm.double_buffer = {str(config.trait.pipeline == "compv4").lower()}; + key.algorithm.persistent = {str(config.trait.persistent).lower()}; + key.algorithm.preshuffle = {str(config.preshuffle).lower()}; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = {config.num_wave_groups}; + + key.gfx_arch = gfx_arch; + + return std::make_shared>(key, "{kernel_name}"); +}} + +}}}}}} +""" + + +# ============================================================================ +# Main Unified Generator +# ============================================================================ + + +class UnifiedGemmCodegen: + """Unified GEMM code generator - single entry point""" + + def __init__( + self, + output_dir: Path, + datatype: str, + layout: str, + gpu_target: str = "gfx942", + config_file: Optional[Path] = None, + variants: List[GemmVariant] = None, + use_preselected: Optional[str] = None, + enable_arch_filter: bool = True, + kernel_set_name: Optional[str] = None, + ): + self.output_dir = Path(output_dir) + self.datatype = datatype + # Support 3-char (rcr) or 4-char (rcrr) layout codes + # 4th char specifies D tensor layout for multi-d + self.layout = layout[:3] # A, B, C layouts + self.d_layout = ( + layout[3] if len(layout) >= 4 else layout[2] + ) # D layout (default = C layout) + self.gpu_target = gpu_target + self.variants = variants or [GemmVariant.STANDARD] + self.use_preselected = use_preselected + self.kernel_set_name = kernel_set_name + + # Create directories - optionally with kernel set subdirectory + if kernel_set_name: + self.kernel_dir = self.output_dir / kernel_set_name + else: + self.kernel_dir = self.output_dir + self.kernel_dir.mkdir(parents=True, exist_ok=True) + self.wrapper_dir = self.kernel_dir / "dispatcher_wrappers" + self.wrapper_dir.mkdir(parents=True, exist_ok=True) + + # Load configuration + self.config = self._load_config(config_file) + + # Initialize architecture filter for GPU-specific validation + self.arch_filter = None + if enable_arch_filter and HAS_ARCH_FILTER: + try: + self.arch_filter = ArchFilter(gpu_target, strict_mode=False) + log.info(f"Architecture filter enabled for {gpu_target}") + except ValueError as e: + log.warning(f"Could not create arch filter: {e}") + + # Initialize generators (use self.layout which is the 3-char A,B,C layout) + self.ck_gen = CKTileKernelGenerator(datatype, self.layout) + self.disp_gen = DispatcherWrapperGenerator(datatype, self.layout) + + def _load_config(self, config_file: Optional[Path]) -> Dict: + """Load or create default configuration""" + if config_file and config_file.exists(): + with open(config_file) as f: + return json.load(f) + + # Match tile_engine default configs for GEMM/Preshuffle/Multi-D + # See: tile_engine/ops/gemm/configs/default_config.json + # tile_engine/ops/gemm_preshuffle/configs/default_config.json + # tile_engine/ops/gemm_multi_d/configs/default_config.json + return { + "tile_config": { + # tile_m/n/k: 64-256 step 64 = [64, 128, 192, 256] + "tile_m": [64, 128, 192, 256], + "tile_n": [64, 128, 192, 256], + "tile_k": [64, 128, 192, 256], + # warp configs matching tile_engine + "warp_m": [1, 2, 4], + "warp_n": [1, 2, 4], + "warp_k": [1], + # warp_tile configs matching tile_engine + "warp_tile_m": [4, 16, 32], + "warp_tile_n": [16, 32, 64], + "warp_tile_k": [8, 16, 32, 64, 128], + }, + "trait_config": { + "pipeline": ["compv3", "compv4", "mem"], + "epilogue": ["cshuffle", "default"], + "scheduler": ["intrawave", "interwave"], + "pad_m": [False], + "pad_n": [False], + "pad_k": [False], + "persistent": [False, True], + }, + "multi_d_config": { + # Note: Only MultiDAdd and MultiDMultiply are compatible with multi-D GEMM. + # Relu/Gelu are unary ops with signature (y, x), not multi-D signature (e, c, ds...) + "elementwise_ops": ["MultiDAdd", "MultiDMultiply"], + "num_d_tensors": [1, 2], + }, + } + + def generate_all(self, parallel: bool = True) -> Dict: + """Generate all kernels""" + log.info("Generating GEMM kernels:") + log.info(f" Datatype: {self.datatype}") + log.info(f" Layout: {self.layout}") + log.info(f" Variants: {[v.value for v in self.variants]}") + if self.use_preselected: + log.info(f" Using preselected set: {self.use_preselected}") + + results = {"kernels": [], "wrappers": [], "failed": []} + + # Get configurations + if self.use_preselected: + configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(configs)}") + else: + for variant in self.variants: + log.info(f"\nGenerating {variant.value} kernels...") + configs = self._get_configs_for_variant(variant) + log.info(f" Configurations: {len(configs)}") + + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._generate_one, cfg) for cfg in configs + ] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + # Generate from preselected set + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self._generate_one, cfg) for cfg in configs] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + def _get_preselected_configs(self) -> List[KernelConfig]: + """Get preselected kernel configurations""" + try: + from preselected_kernels import get_preselected_set + + return get_preselected_set(self.use_preselected) + except ImportError: + log.warning( + "preselected_kernels module not found, falling back to config-based generation" + ) + return [] + except ValueError as e: + log.error(f"Invalid preselected set: {e}") + return [] + + def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]: + """Get all configurations for a variant + + Args: + variant: GEMM variant (STANDARD, PRESHUFFLE, MULTI_D) + + Returns: + List of valid kernel configurations for the variant + """ + configs = [] + + # Get base configs + tile_configs = self._get_tile_configs() + trait_configs = self._get_trait_configs() + + for tile, trait in itertools.product(tile_configs, trait_configs): + # Perform variant-specific architecture validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile, variant): + continue + + if variant == GemmVariant.STANDARD: + configs.append(KernelConfig(tile=tile, trait=trait, variant=variant)) + + elif variant == GemmVariant.PRESHUFFLE: + # Preshuffle needs specific pipeline (preshufflev2) and scheduler (default) + # Skip configs that don't use preshuffle-compatible traits + preshuffle_trait = TraitConfig( + pipeline="preshufflev2", + epilogue="cshuffle", + scheduler="default", + pad_m=trait.pad_m, + pad_n=trait.pad_n, + pad_k=trait.pad_k, + persistent=trait.persistent, + ) + # Only generate one preshuffle config per tile (not per trait) + # since preshuffle has fixed pipeline/scheduler + if trait.pipeline == "compv3" and trait.scheduler == "intrawave": + configs.append( + KernelConfig( + tile=tile, + trait=preshuffle_trait, + variant=variant, + preshuffle=True, + ) + ) + + elif variant == GemmVariant.MULTI_D: + multi_d = self.config.get("multi_d_config", {}) + for ew_op, num_d in itertools.product( + multi_d.get("elementwise_ops", ["MultiDAdd"]), + multi_d.get("num_d_tensors", [1]), + ): + configs.append( + KernelConfig( + tile=tile, + trait=trait, + variant=variant, + elementwise_op=ew_op, + num_d_tensors=num_d, + d_layout=self.d_layout, # Use extracted D layout + ) + ) + + return configs + + def _get_tile_configs(self) -> List[TileConfig]: + """Get valid tile configurations, filtered by architecture constraints""" + tc = self.config["tile_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["tile_m"], + tc["tile_n"], + tc["tile_k"], + tc["warp_m"], + tc["warp_n"], + tc["warp_k"], + tc["warp_tile_m"], + tc["warp_tile_n"], + tc["warp_tile_k"], + ): + tile = TileConfig(*params) + + # Basic validation + if not tile.is_valid(): + rejected_count += 1 + continue + + # Architecture-specific validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile): + rejected_count += 1 + continue + + configs.append(tile) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} tile configs for {self.gpu_target}") + + return configs + + def _is_tile_arch_valid( + self, tile: TileConfig, variant: GemmVariant = None + ) -> bool: + """Check if tile configuration is valid for target architecture + + Args: + tile: Tile configuration to validate + variant: GEMM variant (affects operator-specific constraints) + """ + if not self.arch_filter or not HAS_ARCH_FILTER: + return True + + # Determine data types based on self.datatype + # Note: dtype_c is the ACCUMULATOR type, not output type (which may be fp16) + # WMMA instructions on gfx942 always use fp32 accumulator for fp16 inputs + dtype_map = { + "fp16": ("fp16", "fp16", "fp32"), # A=fp16, B=fp16, Acc=fp32 + "bf16": ("bf16", "bf16", "fp32"), # A=bf16, B=bf16, Acc=fp32 + "fp8": ("fp8", "fp8", "fp32"), # A=fp8, B=fp8, Acc=fp32 + "bf8": ("bf8", "bf8", "fp32"), # A=bf8, B=bf8, Acc=fp32 + "int8": ("int8", "int8", "int32"), # A=int8, B=int8, Acc=int32 + } + dtype_a, dtype_b, dtype_c = dtype_map.get( + self.datatype, ("fp16", "fp16", "fp32") + ) + + # Map GEMM variant to operator type for validation + operator = None + pipeline = "compv4" # Default + scheduler = "intrawave" # Default + + if OperatorType is not None and variant is not None: + variant_to_operator = { + GemmVariant.STANDARD: OperatorType.GEMM, + GemmVariant.PRESHUFFLE: OperatorType.GEMM_PRESHUFFLE, + GemmVariant.MULTI_D: OperatorType.GEMM_MULTI_D, + } + operator = variant_to_operator.get(variant, OperatorType.GEMM) + + # Preshuffle requires specific pipeline and scheduler + if variant == GemmVariant.PRESHUFFLE: + pipeline = "preshufflev2" + scheduler = "default" + + # Use preshuffle-specific validation (comprehensive CK-specific checks) + if variant == GemmVariant.PRESHUFFLE: + if not is_preshuffle_config_valid( + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, + warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + datatype=self.datatype, + ): + return False + + return self.arch_filter.is_kernel_valid( + datatype_a=dtype_a, + datatype_b=dtype_b, + datatype_c=dtype_c, + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, + warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + pipeline=pipeline, + scheduler=scheduler, + layout=self.layout, + operator=operator, + ) + + def _get_trait_configs(self) -> List[TraitConfig]: + """Get valid trait configurations, filtered by architecture constraints""" + tc = self.config["trait_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["pipeline"], + tc["epilogue"], + tc["scheduler"], + tc["pad_m"], + tc["pad_n"], + tc["pad_k"], + tc["persistent"], + ): + trait = TraitConfig(*params) + + # Basic trait validation (unsupported combinations) + if not trait.is_valid(): + rejected_count += 1 + continue + + configs.append(trait) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} trait configs") + + return configs + + def _generate_one(self, config: KernelConfig) -> Tuple[str, str]: + """Generate one kernel and wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + # Generate CK Tile kernel + kernel_code = self.ck_gen.generate(config) + kernel_path = self.kernel_dir / f"{kernel_name}.hpp" + kernel_path.write_text(kernel_code) + + # Generate dispatcher wrapper + wrapper_code = self.disp_gen.generate(config, kernel_path, self.kernel_dir) + wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp" + wrapper_path.write_text(wrapper_code) + + # Generate .cpp compilation unit for per-kernel parallel builds + cpp_path = self.kernel_dir / f"{kernel_name}.cpp" + cpp_code = f'''// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for: {kernel_name} +// Enables per-kernel parallel compilation with make -j + +#include "{kernel_name}.hpp" + +namespace ck_tile {{ namespace generated {{ + volatile bool _{kernel_name.replace("-", "_")}_loaded = true; +}} }} +''' + cpp_path.write_text(cpp_code) + + return str(kernel_path), str(wrapper_path) + + def _generate_registration_header(self, wrapper_paths: List[str]): + """Generate master registration header""" + kernel_names = [ + Path(w).stem.replace("dispatcher_wrapper_", "") for w in wrapper_paths + ] + + includes = "\n".join( + [f'#include "dispatcher_wrapper_{n}.hpp"' for n in kernel_names] + ) + registrations = "\n ".join( + [ + f"registry.register_kernel(generated::make_{n}(gfx_arch), priority);" + for n in kernel_names + ] + ) + + content = f"""// SPDX-License-Identifier: MIT +// Auto-generated master registration +#pragma once + +#include "ck_tile/dispatcher.hpp" +{includes} + +namespace ck_tile {{ +namespace dispatcher {{ + +using ::ck_tile::dispatcher::Registry; +using Priority = ::ck_tile::dispatcher::Registry::Priority; + +inline void register_all_tile_gemm_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = Registry::instance(); + {registrations} +}} + +inline std::size_t get_tile_gemm_kernel_count() {{ return {len(kernel_names)}; }} + +}}}} +""" + + reg_path = self.wrapper_dir / "register_all_kernels.hpp" + reg_path.write_text(content) + logging.info(f"Generated registration header: {reg_path}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def _show_arch_info(gpu_target: str, datatype: str): + """Display supported configurations for a GPU architecture""" + if not HAS_ARCH_FILTER: + print("Architecture filter module not available") + return + + try: + from arch_filter import ( + get_supported_archs, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + + print(f"\n=== Architecture Info for {gpu_target} ===\n") + + # Supported architectures + print(f"Supported GPUs: {get_supported_archs()}") + + # Warp configurations + warp_cfgs = WARP_SUPPORTED_COMBINATIONS.get(gpu_target, []) + print("\nWarp configurations [warp_m, warp_n, warp_k]:") + for cfg in warp_cfgs: + print(f" {cfg}") + + # Warp tile configurations for data type + dtype_map = { + "fp16": "fp16_fp16_fp16", + "bf16": "bf16_bf16_bf16", + "fp8": "fp8_fp8_fp16", + "bf8": "bf8_bf8_fp16", + "int8": "int8_int8_int32", + } + dtype_key = dtype_map.get(datatype, "fp16_fp16_fp16") + + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_target, {}) + warp_tiles = gpu_combos.get(dtype_key, []) + print( + f"\nWarp tile configurations for {dtype_key} [warp_tile_m, warp_tile_n, warp_tile_k]:" + ) + for cfg in warp_tiles: + print(f" {cfg}") + + # All supported data types + print(f"\nAll supported data types on {gpu_target}:") + for dtype in gpu_combos.keys(): + print(f" {dtype}") + + # LDS limits + print("\nLDS capacity limits:") + for pipeline, limit in LDS_CAPACITY_LIMITS.items(): + print(f" {pipeline}: {limit // 1024}KB") + + # Unsupported trait combinations + print("\nUnsupported trait combinations (pipeline, epilogue, scheduler):") + for combo in TRAIT_UNSUPPORTED_COMBINATIONS: + print(f" {combo}") + + print() + + except Exception as e: + print(f"Error showing arch info: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Unified GEMM Code Generator - Single Source of Truth" + ) + parser.add_argument( + "--output-dir", type=Path, required=True, help="Output directory" + ) + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"], + help="Data type (fp16, bf16, fp32, fp8, bf8, int8, pk_fp4)", + ) + parser.add_argument( + "--layout", + type=str, + default="rcr", + help="Layout (e.g., rcr for A=row, B=col, C=row; or rcrr for multi-d with D=row)", + ) + parser.add_argument( + "--gpu-target", + type=str, + default="gfx942", + help="Target GPU (gfx90a, gfx942, gfx950, gfx1201)", + ) + parser.add_argument("--config", type=Path, help="Configuration JSON file") + parser.add_argument( + "--variants", + nargs="+", + choices=["standard", "preshuffle", "multi_d"], + default=["standard"], + help="Variants to generate", + ) + parser.add_argument( + "--preselected", + type=str, + help="Use preselected kernel set (e.g., fp16_rcr_essential)", + ) + parser.add_argument( + "--no-parallel", action="store_true", help="Disable parallel generation" + ) + parser.add_argument( + "--register", action="store_true", help="Generate dispatcher registration code" + ) + parser.add_argument( + "--no-arch-filter", + action="store_true", + help="Disable architecture-specific kernel filtering", + ) + parser.add_argument( + "--show-arch-info", + action="store_true", + help="Show supported configurations for target GPU and exit", + ) + parser.add_argument( + "--kernel-set", + type=str, + help="Kernel set name (creates subdirectory for organization)", + ) + parser.add_argument( + "--tile-config-json", + type=str, + help="JSON string specifying exact tile configuration (for minimal builds)", + ) + + args = parser.parse_args() + + # Handle inline tile config JSON for minimal/single-kernel builds + if args.tile_config_json: + try: + cfg = json.loads(args.tile_config_json) + + # Build proper config structure + full_config = {} + + # Extract tile config + tile_keys = [ + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "block_size", + ] + tile_config = {k: cfg[k] for k in tile_keys if k in cfg} + if tile_config: + full_config["tile_config"] = tile_config + + # Extract trait config + trait_keys = ["pipeline", "epilogue", "scheduler"] + trait_config = {k: cfg[k] for k in trait_keys if k in cfg} + # Add default pad/persistent values + trait_config.setdefault("pad_m", [False]) + trait_config.setdefault("pad_n", [False]) + trait_config.setdefault("pad_k", [False]) + trait_config.setdefault("persistent", [False]) + if trait_config: + full_config["trait_config"] = trait_config + + # Extract multi_d config (for multi_d variant) + if "elementwise_ops" in cfg or "num_d_tensors" in cfg: + multi_d_config = {} + if "elementwise_ops" in cfg: + multi_d_config["elementwise_ops"] = cfg["elementwise_ops"] + if "num_d_tensors" in cfg: + multi_d_config["num_d_tensors"] = cfg["num_d_tensors"] + full_config["multi_d_config"] = multi_d_config + + # Use already structured config if provided + if "tile_config" in cfg: + full_config = cfg + + # Write to temp file and use as config + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + json.dump(full_config, f) + args.config = Path(f.name) + except json.JSONDecodeError as e: + logging.error(f"Invalid tile-config-json: {e}") + return 1 + except KeyError as e: + logging.error(f"Missing required key in tile-config-json: {e}") + return 1 + + # Show architecture info if requested + if args.show_arch_info: + _show_arch_info(args.gpu_target, args.datatype) + return 0 + + variants = [GemmVariant(v) for v in args.variants] if not args.preselected else None + + codegen = UnifiedGemmCodegen( + output_dir=args.output_dir, + datatype=args.datatype, + layout=args.layout, + gpu_target=args.gpu_target, + config_file=args.config, + variants=variants, + use_preselected=args.preselected, + enable_arch_filter=not args.no_arch_filter, + kernel_set_name=args.kernel_set, + ) + + results = codegen.generate_all(parallel=not args.no_parallel) + + logging.info("\n✅ Generation complete!") + logging.info(f" Kernels: {len(results['kernels'])}") + logging.info(f" Wrappers: {len(results['wrappers'])}") + logging.info(f" Failed: {len(results['failed'])}") + + if results["failed"]: + logging.error(f"\nFailed kernels: {len(results['failed'])}") + for err in results["failed"][:5]: + logging.error(f" {err}") + + # Generate dispatcher registration if requested + if args.register: + logging.info("\n📝 Generating dispatcher registration code...") + try: + from generate_dispatcher_registration import ( + scan_generated_headers, + generate_registration_header, + generate_registration_cpp, + ) + + kernels = scan_generated_headers(args.output_dir) + reg_dir = args.output_dir / "registration" + reg_dir.mkdir(exist_ok=True) + + generate_registration_header( + kernels, reg_dir / "dispatcher_registration.hpp" + ) + generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") + + logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + except Exception as e: + logging.error(f"Failed to generate registration code: {e}") + return 1 + + return 0 if not results["failed"] else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt new file mode 100644 index 0000000000..0359eb0d8d --- /dev/null +++ b/dispatcher/examples/CMakeLists.txt @@ -0,0 +1,448 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +# Get processor count for parallel builds +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 4) +endif() + +# GPU target architecture (passed from command line or default to gfx942) +if(NOT DEFINED GPU_TARGETS OR GPU_TARGETS STREQUAL "") + set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture target") +endif() +# Extract first target if multiple are provided (we only support single target builds) +string(REPLACE ";" " " GPU_TARGETS_SPACE "${GPU_TARGETS}") +string(REPLACE " " ";" GPU_TARGETS_LIST "${GPU_TARGETS_SPACE}") +list(GET GPU_TARGETS_LIST 0 GPU_TARGET) +message(STATUS "Building for GPU target: ${GPU_TARGET}") + +# NOTE: Per-kernel compilation is now automatic via declarative examples +# Each example generates only its declared kernels (from DECL_KERNEL_SET) + +# Link to dispatcher library +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build) + +# ============================================================================= +# Kernel Output Directory +# ============================================================================= + +set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels") +file(MAKE_DIRECTORY ${KERNEL_OUTPUT_DIR}) + +# ============================================================================= +# Kernel Generation Targets (run during 'make', not 'cmake') +# ============================================================================= + +# Sentinel files to track generation +set(GEMM_SENTINEL "${KERNEL_OUTPUT_DIR}/.gemm_generated") + +# Generate GEMM kernels (standard + preshuffle + multi_d) - runs with internal parallelism +# Note: 4-char layout "rcrr" means A=row, B=col, C=row, D=row (for multi-d) +add_custom_command( + OUTPUT ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcrr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels (fp16, rcrr, standard + preshuffle + multi_d) with internal parallelism..." + VERBATIM +) + +add_custom_target(generate_gemm_kernels + DEPENDS ${GEMM_SENTINEL} + COMMENT "GEMM kernel generation target" +) + +# Alias for generate_all_kernels (GEMM only now) +add_custom_target(generate_all_kernels + DEPENDS generate_gemm_kernels +) + +# ============================================================================= +# Per-Kernel Compilation (Maximum Parallelism) +# ============================================================================= +# Enable with: cmake -DPER_KERNEL_COMPILATION=ON +# +# This creates ONE translation unit per kernel, enabling: +# 1. Maximum parallelism with make -j$(nproc) +# 2. Per-kernel build progress: "[1/128] Building kernel: gemm_fp16_128x128" +# 3. Incremental rebuilds (only changed kernels recompile) +# 4. Fine-grained build time analysis +# +# Build process: +# 1. Generate kernel headers (.hpp) +# 2. Generate wrapper files (.cpp) - one per kernel +# 3. Compile each wrapper in parallel +# 4. Link all objects into libdispatcher_kernels.so +# +# Example output: +# [ 1/128] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/128] Building kernel: gemm_fp16_rcr_256x256x64 +# ... +# [128/128] Linking: libdispatcher_kernels.so +# ============================================================================= + +set(WRAPPER_DIR "${CMAKE_BINARY_DIR}/kernel_wrappers") +set(WRAPPER_SENTINEL "${WRAPPER_DIR}/.wrappers_generated") + +# Target: Generate wrapper .cpp files (one per kernel) +add_custom_command( + OUTPUT ${WRAPPER_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/generate_kernel_wrappers.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${WRAPPER_DIR} + --generate-makefile + --generate-cmake + COMMAND ${CMAKE_COMMAND} -E touch ${WRAPPER_SENTINEL} + DEPENDS ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating per-kernel wrapper .cpp files..." + VERBATIM +) + +add_custom_target(generate_kernel_wrappers + DEPENDS ${WRAPPER_SENTINEL} + COMMENT "Kernel wrapper generation target" +) + +# Target: Build kernels using generated Makefile (true per-kernel progress) +add_custom_target(build_kernels_parallel + COMMAND ${CMAKE_COMMAND} -E echo "Building kernels with per-kernel progress..." + COMMAND make -C ${WRAPPER_DIR} -j${NPROC} 2>&1 | grep -E "^\\[|Built|Linking|Error" + DEPENDS generate_kernel_wrappers + WORKING_DIRECTORY ${WRAPPER_DIR} + COMMENT "Compiling kernels in parallel (one translation unit per kernel)..." + VERBATIM +) + +# Global kernel build (optional - prefer per-example builds for minimal compilation) +# This builds ALL kernels into a shared library - use for Python bindings or full library +# For C++ examples, use declarative approach which builds only needed kernels +add_custom_target(dispatcher_kernels + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/parallel_kernel_builder.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${CMAKE_BINARY_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --jobs ${NPROC} + DEPENDS generate_all_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "Building ALL kernels in parallel (prefer per-example builds for minimal compilation)..." + VERBATIM +) + +# ============================================================================= +# Force regeneration targets (useful when you want to regenerate) +# ============================================================================= + +add_custom_target(regenerate_gemm_kernels + COMMAND ${CMAKE_COMMAND} -E remove -f ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Force regenerating GEMM kernels (standard + preshuffle + multi_d)..." + VERBATIM +) + +add_custom_target(regenerate_all_kernels + DEPENDS regenerate_gemm_kernels +) + +# Clean all per-example kernel directories +add_custom_target(clean_example_kernels + COMMAND ${CMAKE_COMMAND} -E echo "Removing per-example kernel directories..." + COMMAND find ${CMAKE_BINARY_DIR} -maxdepth 1 -type d -name "*_kernels" -exec rm -rf {} + + COMMENT "Cleaning all per-example kernel directories..." + VERBATIM +) + +# ============================================================================= +# Helper function to add a GPU example with force-included kernel +# ============================================================================= + +# Helper for GPU examples that use the dispatcher registry +# KERNEL_HEADER can be: +# - A registration header (register_all_kernels.hpp) - included directly in source +# - A specific kernel header - force-included via compiler flag +function(add_gpu_example NAME SOURCE KERNEL_HEADER) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels/dispatcher_wrappers # Wrapper headers + ) + + # Check if using registration header (no force-include needed) + get_filename_component(HEADER_NAME ${KERNEL_HEADER} NAME) + if(HEADER_NAME STREQUAL "register_all_kernels.hpp") + # Registration header - examples include it directly + target_compile_options(${NAME} PRIVATE + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + else() + # Specific kernel header - force-include it + target_compile_options(${NAME} PRIVATE + -include ${KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + endif() + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for standalone GPU examples (instantiate kernel directly, no pre-generated header) +function(add_standalone_gpu_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels (optional) + ) + + target_compile_options(${NAME} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for declarative examples (configuration demo, still needs HIP compiler for CK headers) +function(add_declarative_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ) + + target_compile_options(${NAME} PRIVATE + -Wno-float-equal + -Wno-unused-variable + -Wno-undefined-func-template + -mllvm -enable-noalias-to-md-conversion=0 + ) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# ============================================================================= +# GEMM Examples +# ============================================================================= + +# Per-example kernel directories are created from DECL_KERNEL_SET declarations +# Each example gets its own: build/_kernels/ +# This prevents clashes during parallel compilation of multiple examples. + +# Helper function to add example with declarative kernel support +# Parses DECL_KERNEL_SET from source and generates ONLY the declared kernels +# This enables minimal builds: only kernels needed by this example are generated +# +# Key features: +# - Per-example kernel directories: build/_kernels/ (no clashes) +# - Automatic header inclusion: No hardcoded #include needed in source +# - Minimal builds: Only declared kernels are generated +# - Auto-regeneration: Kernels regenerated if directory missing +# - Parallel compilation: Each kernel is a separate translation unit +function(add_declarative_gpu_example NAME SOURCE) + set(EXAMPLE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE}") + get_filename_component(EXAMPLE_STEM ${SOURCE} NAME_WE) + + # Per-example kernel directories + set(EXAMPLE_KERNEL_DIR "${CMAKE_BINARY_DIR}/${NAME}_kernels") + set(EXAMPLE_HEADER "${EXAMPLE_KERNEL_DIR}/${EXAMPLE_STEM}_kernels.hpp") + set(EXAMPLE_LIB "${EXAMPLE_KERNEL_DIR}/lib${NAME}_kernels.a") + set(EXAMPLE_SENTINEL "${EXAMPLE_KERNEL_DIR}/.generated") + + # Generate AND compile kernels in parallel at make time + # This avoids slow cmake and gets per-kernel progress + add_custom_command( + OUTPUT ${EXAMPLE_SENTINEL} ${EXAMPLE_LIB} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py + ${EXAMPLE_SOURCE} + --output-dir ${EXAMPLE_KERNEL_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --gpu-target ${GPU_TARGET} + --jobs ${NPROC} + --target-name ${NAME} + COMMAND ${CMAKE_COMMAND} -E touch ${EXAMPLE_SENTINEL} + DEPENDS ${EXAMPLE_SOURCE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "[${NAME}] Generating and compiling kernels from DECL_KERNEL_SET..." + VERBATIM + ) + + add_custom_target(generate_${NAME}_kernels DEPENDS ${EXAMPLE_SENTINEL}) + + # Add the executable + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + # Link against the per-example kernel library + target_link_libraries(${NAME} PRIVATE ${EXAMPLE_LIB}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${EXAMPLE_KERNEL_DIR} + ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers + ) + + # Force-include the generated registration header + target_compile_options(${NAME} PRIVATE + -include ${EXAMPLE_HEADER} + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() + + # Only depends on generating THIS example's kernels + add_dependencies(${NAME} generate_${NAME}_kernels) +endfunction() + +# GEMM C++ examples with declarative kernel support +# Each example's C++ code contains DECL_KERNEL_SET which declares needed kernels +add_declarative_gpu_example(gemm_01_basic gemm/cpp/01_basic_gemm.cpp) +add_declarative_gpu_example(gemm_02_multi_size gemm/cpp/02_multi_size.cpp) +add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) +add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) +add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) + +# ============================================================================= +# GEMM Python Library - Single Fallback Kernel +# ============================================================================= + +# Generate a single fallback kernel for the Python library (fp16, rcr, compv4) +set(GEMM_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/gemm_python_fallback") +set(GEMM_FALLBACK_KERNEL "${GEMM_FALLBACK_KERNEL_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + +# Tile config JSON for single kernel generation +set(GEMM_FALLBACK_TILE_CONFIG "{\"tile_m\":[128],\"tile_n\":[128],\"tile_k\":[32],\"warp_m\":[2],\"warp_n\":[2],\"warp_k\":[1],\"warp_tile_m\":[32],\"warp_tile_n\":[32],\"warp_tile_k\":[16],\"pipeline\":[\"compv4\"],\"scheduler\":[\"intrawave\"],\"epilogue\":[\"cshuffle\"]}") + +# Generate single fallback kernel (not all 6000+ kernels) +add_custom_command( + OUTPUT ${GEMM_FALLBACK_KERNEL} + COMMAND ${CMAKE_COMMAND} -E make_directory ${GEMM_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard + --gpu-target ${GPU_TARGET} + --output-dir ${GEMM_FALLBACK_KERNEL_DIR} + --tile-config-json "${GEMM_FALLBACK_TILE_CONFIG}" + COMMENT "Generating single fallback GEMM kernel for Python library" + VERBATIM +) + +add_custom_target(generate_gemm_fallback_kernel DEPENDS ${GEMM_FALLBACK_KERNEL}) + +# GEMM dynamic library for Python +add_library(dispatcher_gemm_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/gemm_ctypes_lib.cpp) +target_link_libraries(dispatcher_gemm_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_gemm_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${GEMM_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_gemm_lib PRIVATE + -DCK_TILE_SINGLE_KERNEL_INCLUDE + -include ${GEMM_FALLBACK_KERNEL} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) + +message(STATUS "GEMM examples configured - kernels will be generated during 'make'") + +# Convenience target to build all Python ctypes libraries +add_custom_target(python_libs + DEPENDS dispatcher_gemm_lib + COMMENT "Building Python ctypes libraries (GEMM)" +) + +# ============================================================================= +# Per-Architecture Kernel Generation Targets +# ============================================================================= + +set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) + +foreach(ARCH ${SUPPORTED_GPU_ARCHS}) + # GEMM kernels for this arch + add_custom_target(generate_gemm_kernels_${ARCH} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --gpu-target ${ARCH} + --output ${KERNEL_OUTPUT_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels for ${ARCH}..." + VERBATIM + ) + + # Alias for kernels (GEMM only now) + add_custom_target(generate_kernels_${ARCH} + DEPENDS generate_gemm_kernels_${ARCH} + COMMENT "Generating all kernels for ${ARCH}..." + ) +endforeach() + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "") +message(STATUS "=== Dispatcher Examples Configuration ===") +message(STATUS "") +message(STATUS "Kernels will be generated automatically during 'make'") +message(STATUS " Generated to: ${KERNEL_OUTPUT_DIR}") +message(STATUS "") +message(STATUS "Build targets:") +message(STATUS " make - Build all examples (generates kernels first)") +message(STATUS " make python_libs - Build Python ctypes libraries") +message(STATUS " make generate_all_kernels - Generate all kernels only") +message(STATUS " make regenerate_all_kernels - Force regenerate all kernels") +message(STATUS "") +message(STATUS "Per-architecture targets:") +message(STATUS " make generate_kernels_ - Generate for specific arch") +message(STATUS " Supported archs: ${SUPPORTED_GPU_ARCHS}") +message(STATUS "") diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md new file mode 100644 index 0000000000..fdee9c3583 --- /dev/null +++ b/dispatcher/examples/README.md @@ -0,0 +1,210 @@ +# CK Tile Dispatcher Examples + +Comprehensive examples for GEMM operations with GPU execution. + +> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. + +--- + +## Quick Start + +### Step 1: Build + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build everything (C++ examples + Python libraries) +make -j$(nproc) + +# Or build ONLY Python libraries (faster) +make python_libs -j$(nproc) +``` + +### Step 2: Run C++ Examples + +```bash +cd build/examples + +# GEMM +./gemm_01_basic +./gemm_02_multi_size +./gemm_03_benchmark_validation +./gemm_04_heuristics +./gemm_05_json_export +./gemm_06_multi_registry +``` + +### Step 3: Run Python Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +--- + +## Directory Structure + +``` +examples/ +├── gemm/ +│ ├── cpp/ # 6 C++ GEMM examples +│ └── python/ # 11 Python GEMM examples +│ +└── README.md +``` + +--- + +## GEMM Examples + +### C++ Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `gemm_01_basic` | Basic GEMM with declarative API, autofill, autocorrect | +| 02 | `gemm_02_multi_size` | Wildcard expansion for multiple configurations | +| 03 | `gemm_03_benchmark_validation` | Performance benchmarking with CPU/GPU validation | +| 04 | `gemm_04_heuristics` | Heuristic-based kernel selection | +| 05 | `gemm_05_json_export` | Registry JSON export for external tools | +| 06 | `gemm_06_multi_registry` | Multiple registries with named kernel sets | + +**Details:** [gemm/cpp/README.md](gemm/cpp/README.md) + +--- + +### Python Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `01_basic_gemm.py` | Basic GEMM with multi-kernel support | +| 02 | `02_batch_gemm.py` | Batched GEMM operations | +| 03 | `03_benchmark.py` | Performance benchmarking | +| 04 | `04_validation.py` | CPU reference validation | +| 05 | `05_numpy_integration.py` | NumPy array integration | +| 06 | `06_json_export.py` | Registry JSON export | +| 07 | `07_stress_test.py` | Multi-kernel stress testing (48 configs) | +| 08 | `08_heuristics.py` | Heuristic-based kernel selection (24 configs) | +| 09 | `09_multi_registry.py` | Multiple registries | +| 10 | `10_advanced_benchmark.py` | Advanced benchmark with full control | +| 11 | `11_json_import.py` | Import kernels from JSON | + +**Details:** [gemm/python/README.md](gemm/python/README.md) + +--- + +## Key Features + +### Declarative Kernel API + +Both C++ and Python examples use a declarative approach: + +**C++ (DECL_KERNEL_SET macro):** +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Python (KernelConfig):** +```python +config = KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" +) +``` + +### Autofill and Autocorrect + +The build system automatically: +- **Autofills** missing parameters with sensible defaults +- **Autocorrects** invalid parameters based on architecture constraints +- **Expands** wildcards (`*`, `-1`, `ANY_INT`) to all valid configurations + +### Architecture Filtering + +Kernel configurations are validated against GPU architecture constraints: +- Tile divisibility requirements +- Warp tile constraints +- Pipeline compatibility + +Invalid configurations are automatically pruned during code generation. + +--- + +## Validation Examples + +### C++ Validation + +```bash +./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference +./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference +``` + +### Python Validation + +```bash +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation +``` + +--- + +## Troubleshooting + +### Python: Library not found + +```bash +# Run from dispatcher directory +cd /path/to/composable_kernel/dispatcher +python3 examples/gemm/python/01_basic_gemm.py +``` + +### C++: Executables not found + +```bash +# Build with examples enabled +cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON +make -j$(nproc) + +# Run from build/examples +cd build/examples +./gemm_01_basic +``` + +### GPU not detected + +```bash +rocminfo | grep "Name:" +# Should show: gfx942, gfx90a, etc. +``` + +--- + +## Archived Examples + +Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples + +See the archive for convolution functionality reference. diff --git a/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp new file mode 100644 index 0000000000..80b584a842 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp @@ -0,0 +1,243 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 01: Basic GEMM - Autofill, Autocorrect, and Full Declaration + * + * Demonstrates THREE declaration patterns: + * + * 1. AUTOFILL: Minimal declaration - missing params filled with defaults + * .add(Signature().dtype("fp16").layout("rcr"), + * Algorithm().tile(128,128,64).pipeline("compv3").scheduler("intrawave"), + * "gfx942") + * -> wave(2,2,1), warp(32,32,16), epilogue("cshuffle") added automatically + * + * 2. AUTOCORRECT: Invalid params corrected to valid values + * .add(..., Algorithm().wave(1,1,1)...) + * -> wave(1,1,1) is invalid for gfx942, corrected to wave(2,2,1) + * + * 3. FULL: All parameters explicitly specified + * .add(..., Algorithm().tile().wave().warp().pipeline().scheduler().epilogue()...) + * + * Build: cd dispatcher/build && cmake .. && make gemm_01_basic + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// THREE KERNEL DECLARATION PATTERNS +// ============================================================================= + +DECL_KERNEL_SET( + basic_gemm_kernels, + // ------------------------------------------------------------------------- + // Pattern 1: AUTOFILL - Minimal declaration + // Only specify: dtype, layout, tile, pipeline, scheduler + // Auto-filled: wave(2,2,1), warp(32,32,16), epilogue("cshuffle"), pad(false,false,false) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Required + .pipeline("compv3") // Required + .scheduler("intrawave"), // Required + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 2: AUTOCORRECT - Invalid wave config + // wave(1,1,1) is invalid for gfx942 WMMA, corrected to wave(2,2,1) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Different tile_k to make unique kernel + .wave(1, 1, 1) // INVALID: autocorrected to (2,2,1) + .warp(32, 32, 16) // Valid warp for 128x128 tile + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 3: FULL - All parameters explicitly specified + // No autofill or autocorrect needed + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) // Explicit tile + .wave(2, 2, 1) // Explicit wave (valid) + .warp(16, 16, 32) // Explicit warp tile + .pipeline("compv3") // Explicit pipeline + .scheduler("intrawave") // Explicit scheduler + .epilogue("cshuffle") // Explicit epilogue + .pad(false, false, false), // Explicit padding + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 01: GEMM Autofill/Autocorrect/Full", + "Three kernel declaration patterns"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--size", "1024", "Problem size MxNxK"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 01: GEMM Declaration Patterns"); + + // ========================================================================= + // Show the Three Patterns + // ========================================================================= + std::cout << "\nTHREE DECLARATION PATTERNS:\n"; + std::cout << "============================\n\n"; + + std::cout << "1. AUTOFILL (minimal declaration):\n"; + std::cout << " .add(Signature().dtype(\"fp16\").layout(\"rcr\"),\n"; + std::cout + << " Algorithm().tile(128,128,64).pipeline(\"compv3\").scheduler(\"intrawave\"),\n"; + std::cout << " \"gfx942\")\n"; + std::cout << " -> Auto-filled: wave(2,2,1), warp(32,32,16), epilogue(\"cshuffle\")\n\n"; + + std::cout << "2. AUTOCORRECT (invalid params fixed):\n"; + std::cout << " .add(..., Algorithm().wave(1,1,1)...)\n"; + std::cout << " -> wave(1,1,1) invalid for gfx942, corrected to wave(2,2,1)\n\n"; + + std::cout << "3. FULL (all params explicit):\n"; + std::cout << " .add(..., " + "Algorithm().tile().wave().warp().pipeline().scheduler().epilogue().pad()...)\n"; + std::cout << " -> No changes needed\n\n"; + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show Declared Kernel Sets + // ========================================================================= + std::cout << "Step 1: Declared Kernel Sets\n"; + KernelSetRegistry::instance().print(); + + const auto& decl_set = KernelSetRegistry::instance().get("basic_gemm_kernels"); + std::cout << " 'basic_gemm_kernels': " << decl_set.size() << " declaration(s)\n"; + + // ========================================================================= + // Step 2: Create Registry and Register Kernels + // ========================================================================= + std::cout << "\nStep 2: Register Kernels\n"; + + Registry registry; + // Use generic macro + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // List kernels if requested + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + // ========================================================================= + // Step 3: Create Dispatcher + // ========================================================================= + std::cout << "\nStep 3: Create Dispatcher\n"; + Dispatcher dispatcher(®istry); + + // ========================================================================= + // Step 4: Setup Problem + // ========================================================================= + int size = args.get_int("--size", 1024); + const int M = size, N = size, K = size; + + std::cout << "\nStep 4: Setup Problem (" << M << "x" << N << "x" << K << ")\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Step 5: Select and Run + // ========================================================================= + std::cout << "\nStep 5: Select and Run\n"; + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Step 6: Verify + // ========================================================================= + std::cout << "\nStep 6: Verify\n"; + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected: " << expected << ", Errors: " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "DECLARATION PATTERNS SUMMARY:\n"; + print_separator(); + std::cout << R"( + 1. AUTOFILL: Specify only required params, system fills defaults + - Useful for quick prototyping + - Guarantees valid configuration + + 2. AUTOCORRECT: System validates and fixes invalid params + - wave(1,1,1) -> wave(2,2,1) on gfx942 + - Invalid pipeline/scheduler combos fixed + - Logs corrections for debugging + + 3. FULL: All params explicit - no changes made + - Full control over configuration + - Best for production/tuning +)"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp new file mode 100644 index 0000000000..5e620209f4 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -0,0 +1,215 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 02: Multi-Size GEMM with Wildcard Expansion + * + * Demonstrates the WILDCARD feature where specifying wildcards causes + * the build system to expand to ALL valid configurations for the architecture. + * + * WILDCARD SYNTAX: + * - Integer params: ANY_INT or -1 (both are equivalent, ANY_INT is just a #define for -1) + * - String params: "*" (for pipeline, scheduler) + * + * The kernel declaration: + * .add(..., Algorithm().tile(64,64,64).wave(ANY_INT,ANY_INT,1).warp(-1,-1,-1) + * .pipeline("*").scheduler("*"), ...) + * + * Expands to multiple kernels: + * - wave: (1,4,1), (2,2,1), (4,1,1) -> 3 options + * - warp: (16,16,32), (32,32,16) -> 2 options + * - pipeline: "compv3" -> 1 option (compv4 requires special handling) + * - scheduler: "intrawave" -> 1 option + * + * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m × warp_tile_m) + * - tile_n must be divisible by (warp_n × warp_tile_n) + * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) + * Result: 4 valid wildcard kernels + 1 explicit = 5 total + * + * Build: cd dispatcher/build && cmake .. && make gemm_02_multi_size + * Usage: ./gemm_02_multi_size [--max-size N] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Demonstrates Wildcard Expansion +// ============================================================================= + +DECL_KERNEL_SET(multi_size_kernels, + // ------------------------------------------------------------------------- + // Kernel 1: Explicit - all parameters specified (no expansion) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Kernel 2: WILDCARD - expands to multiple valid configurations + // Wildcards: ANY_INT == -1 (for integers), "*" (for strings) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) + .pipeline("*") // "*" → valid pipelines + .scheduler("*") // "*" → valid schedulers + .epilogue("cshuffle"), + "gfx942")); +// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 02: Multi-Size GEMM with Wildcards", + "Demonstrates wildcard expansion for kernel generation"); + args.add_option("--max-size", "4096", "Maximum problem size to test"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all registered kernels"); + args.add_flag("--list-verbose", "List kernels with full configuration details"); + + if(!args.parse(argc, argv)) + return 0; + + int max_size = args.get_int("--max-size", 4096); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 02: Multi-Size GEMM with Wildcards"); + + // ========================================================================= + // Show Wildcard Expansion Concept + // ========================================================================= + std::cout << "\nWILDCARD EXPANSION:\n"; + std::cout << "===================\n"; + std::cout << R"( + Wildcard syntax: + ANY_INT or -1 -> expands integer params to all valid values + "*" -> expands string params (pipeline/scheduler) to valid values + + Declaration with wildcards: + .tile(64, 64, 64) -> fixed tile size (no wildcard) + .wave(ANY_INT, ANY_INT, 1) -> expands to (1,4,1), (2,2,1), (4,1,1) = 3 + .warp(-1, -1, -1) -> expands to (16,16,32), (32,32,16) = 2 + .pipeline("*") -> expands to valid pipelines = 1 + .scheduler("*") -> expands to valid schedulers = 1 + + Expanded: 3 × 2 = 6 configs, but arch filter validates each: + - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + - Result: 4 valid kernels from wildcard + 1 explicit = 5 total +)"; + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + std::cout << "\nStep 1: Register Kernels\n"; + std::cout << "------------------------\n"; + + Registry registry; + registry.set_name("multi_size_registry"); + + // Register kernels from generated header (includes expanded wildcards) + // Use generic macro - no need to hardcode example name + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s) from wildcard expansion\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + Dispatcher dispatcher(®istry); + std::cout << " Max size: " << max_size << "\n"; + + // ========================================================================= + // Run Multiple Problem Sizes + // ========================================================================= + std::cout << "\nStep 2: Run Multiple Sizes\n"; + print_separator(); + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "Time(ms)" << std::setw(12) << "TFLOPS" << "\n"; + print_separator(); + + std::vector> all_sizes = { + {256, 256, 256}, + {512, 512, 512}, + {1024, 1024, 1024}, + {2048, 2048, 2048}, + {4096, 4096, 4096}, + }; + + std::vector> sizes; + for(const auto& [M, N, K] : all_sizes) + { + if(std::max({M, N, K}) <= max_size) + sizes.push_back({M, N, K}); + } + + using DataType = ck_tile::fp16_t; + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem(M, N, K); + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << std::fixed << std::setprecision(4) << time_ms << std::setw(12) + << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + if(errors > 0) + all_passed = false; + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp new file mode 100644 index 0000000000..61608c7914 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp @@ -0,0 +1,344 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 03: GEMM Benchmark & Validation + * + * Combined example demonstrating: + * 1. Benchmarking with statistics (warmup, iterations, min/max/mean/median) + * 2. Validation against CK Tile reference (CPU or GPU) + * + * Build: cd dispatcher/build && cmake .. && make gemm_03_benchmark_validation + * Usage: ./gemm_03_benchmark_validation [--size N] [--verify MODE] [--benchmark] + * + * Options: + * --size N Problem size MxNxK (default: 512) + * --verify MODE 0=none, 1=CPU ref, 2=GPU ref (default: 1) + * --benchmark Run full benchmark with statistics + * --warmup N Warmup iterations (default: 5) + * --iterations N Benchmark iterations (default: 20) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/reference/reference_gemm.hpp" + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using namespace ck_tile::literals; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: High-performance kernels for benchmarking/validation +// ============================================================================= + +DECL_KERNEL_SET(benchmark_validation_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Helper: Layout detection +// ============================================================================= + +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 03: GEMM Benchmark & Validation", + "Benchmark and/or validate GEMM output against reference"); + args.add_option("--size", "512", "Problem size MxNxK"); + args.add_option("--verify", "1", "Verification: 0=none, 1=CPU ref, 2=GPU ref"); + args.add_flag("--benchmark", "Run benchmark with statistics"); + args.add_option("--warmup", "5", "Warmup iterations"); + args.add_option("--iterations", "20", "Benchmark iterations"); + args.add_option("--rtol", "0.01", "Relative tolerance"); + args.add_option("--atol", "0.01", "Absolute tolerance"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + int M = args.get_int("--size", 512); + int N = M; + int K = M; + int verify = args.get_int("--verify", 1); + bool do_benchmark = args.has("--benchmark"); + int warmup = args.get_int("--warmup", 5); + int iterations = args.get_int("--iterations", 20); + float rtol = args.get_float("--rtol", 0.01f); + float atol = args.get_float("--atol", 0.01f); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 03: GEMM Benchmark & Validation"); + + std::cout << "\nConfiguration:\n"; + std::cout << " Problem: " << M << " x " << N << " x " << K << "\n"; + std::cout << " Layout: RCR (A=row, B=col, C=row)\n"; + std::cout << " Verify: " << verify; + if(verify == 0) + std::cout << " (disabled)"; + else if(verify == 1) + std::cout << " (CPU reference)"; + else if(verify == 2) + std::cout << " (GPU reference)"; + std::cout << "\n"; + std::cout << " Benchmark: " << (do_benchmark ? "yes" : "no") << "\n"; + if(do_benchmark) + { + std::cout << " Warmup: " << warmup << " iterations\n"; + std::cout << " Measure: " << iterations << " iterations\n"; + } + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + Dispatcher dispatcher(®istry); + + std::cout << " Kernels: " << registry.size() << " registered\n"; + print_registered_kernels(registry); + + // ========================================================================= + // Initialize data with proper tensor descriptors + // ========================================================================= + using ALayout = ck_tile::tensor_layout::gemm::RowMajor; + using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + + using ADataType = ck_tile::fp16_t; + using BDataType = ck_tile::fp16_t; + using CDataType = ck_tile::fp16_t; + using AccDataType = float; + + auto stride_a = ck_tile::get_default_stride(M, K, 0_uz, is_row_major(ALayout{})); + auto stride_b = ck_tile::get_default_stride(K, N, 0_uz, is_row_major(BLayout{})); + auto stride_c = ck_tile::get_default_stride(M, N, 0_uz, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_a, is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_b, is_row_major(BLayout{}))); + ck_tile::HostTensor c_m_n_dev( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + ck_tile::HostTensor c_m_n_ref( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(b_k_n); + + std::cout << "\nData:\n"; + std::cout << " A: " << M << " x " << K << " (fp16, row-major)\n"; + std::cout << " B: " << K << " x " << N << " (fp16, col-major)\n"; + std::cout << " C: " << M << " x " << N << " (fp16, row-major)\n"; + + // GPU memory + ck_tile::DeviceMem a_dev(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_dev(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev(c_m_n_dev.get_element_space_size_in_bytes()); + + a_dev.ToDevice(a_m_k.data()); + b_dev.ToDevice(b_k_n.data()); + + // ========================================================================= + // Compute Reference (if needed) + // ========================================================================= + if(verify > 0) + { + std::cout << "\nComputing reference...\n"; + c_m_n_ref.SetZero(); + + if(verify == 1) + { + std::cout << " Using CPU reference (ck_tile::reference_gemm)\n"; + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + } + else if(verify == 2) + { + std::cout << " Using GPU reference (ck_tile::reference_gemm_gpu)\n"; + ck_tile::DeviceMem c_ref_dev(c_m_n_ref.get_element_space_size_in_bytes()); + c_ref_dev.SetZero(); + + ck_tile::reference_gemm_gpu( + static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_ref_dev.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c); + + (void)hipDeviceSynchronize(); + c_ref_dev.FromDevice(c_m_n_ref.data()); + } + std::cout << " Reference complete.\n"; + } + + // ========================================================================= + // Run Kernel + // ========================================================================= + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "\nRunning kernel:\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + + c_dev.SetZero(); + float time_ms = 0.0f; + std::vector times; + + if(do_benchmark) + { + // Warmup + std::cout << " Warming up (" << warmup << " iterations)...\n"; + for(int i = 0; i < warmup; ++i) + { + c_dev.SetZero(); + (void)dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + // Benchmark + std::cout << " Benchmarking (" << iterations << " iterations)...\n"; + times.reserve(iterations); + for(int i = 0; i < iterations; ++i) + { + c_dev.SetZero(); + float t = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + times.push_back(t); + } + time_ms = *std::min_element(times.begin(), times.end()); + } + else + { + // Single run + time_ms = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + c_dev.FromDevice(c_m_n_dev.data()); + + // ========================================================================= + // Results + // ========================================================================= + double flops = 2.0 * M * N * K; + double tflops = flops / (time_ms * 1e9); + + print_separator(); + std::cout << "Performance:\n"; + print_separator(); + + if(do_benchmark && !times.empty()) + { + std::sort(times.begin(), times.end()); + float min_t = times.front(); + float max_t = times.back(); + float median_t = times[times.size() / 2]; + float mean_t = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + + std::cout << std::fixed << std::setprecision(4); + std::cout << " Min: " << min_t << " ms (" << std::setprecision(2) + << (flops / (min_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Max: " << max_t << " ms\n"; + std::cout << " Mean: " << mean_t << " ms (" << std::setprecision(2) + << (flops / (mean_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Median: " << median_t << " ms (" << std::setprecision(2) + << (flops / (median_t * 1e9)) << " TFLOPS)\n"; + } + else + { + std::cout << std::fixed << std::setprecision(4); + std::cout << " Time: " << time_ms << " ms\n"; + std::cout << std::setprecision(2); + std::cout << " TFLOPS: " << tflops << "\n"; + } + + // ========================================================================= + // Validation + // ========================================================================= + bool pass = true; + + if(verify > 0) + { + print_separator(); + std::cout << "Validation:\n"; + print_separator(); + std::cout << " Tolerance: rtol=" << rtol << ", atol=" << atol << "\n"; + + pass = ck_tile::check_err(c_m_n_dev, c_m_n_ref, "Validation Error!", rtol, atol); + + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + for(size_t i = 0; i < c_m_n_dev.get_element_space_size(); ++i) + { + float dev_val = static_cast(c_m_n_dev.mData[i]); + float ref_val = static_cast(c_m_n_ref.mData[i]); + float abs_diff = std::abs(dev_val - ref_val); + float rel_diff = (ref_val != 0.0f) ? abs_diff / std::abs(ref_val) : abs_diff; + max_abs_diff = std::max(max_abs_diff, abs_diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + + std::cout << " Max abs diff: " << max_abs_diff << "\n"; + std::cout << " Max rel diff: " << max_rel_diff << "\n"; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Result: " << (pass ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return pass ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/04_heuristics.cpp b/dispatcher/examples/gemm/cpp/04_heuristics.cpp new file mode 100644 index 0000000000..2a8753cdff --- /dev/null +++ b/dispatcher/examples/gemm/cpp/04_heuristics.cpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 04: Custom Heuristics + * + * Demonstrates custom kernel selection heuristics for different workloads. + * + * Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple tile sizes for heuristic-based selection +// ============================================================================= + +DECL_KERNEL_SET(heuristics_kernels, + // Small tile - low latency + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + // Medium tile - balanced + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Custom Heuristic +// ============================================================================= + +std::vector size_based_heuristic(const Problem& problem) +{ + std::vector ranked_kernels; + int64_t total_elements = problem.M * problem.N; + + if(total_elements < 100000) + { + ranked_kernels = {"gemm_64x64", "gemm_128x128"}; + } + else + { + ranked_kernels = {"gemm_128x128", "gemm_64x64"}; + } + return ranked_kernels; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 04: Custom Heuristics", + "Demonstrates custom kernel selection heuristics"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 04: Custom Heuristics"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + Dispatcher dispatcher(®istry); + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(size_based_heuristic); + + std::cout << "\nSetup:\n"; + std::cout << " Registry: " << registry.size() << " kernel(s)\n"; + std::cout << " Strategy: Heuristic (size-based)\n"; + + // ========================================================================= + // Test Different Problem Sizes + // ========================================================================= + std::cout << "\nTesting heuristic selection:\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + std::vector> sizes = { + {128, 128, 64}, + {512, 512, 256}, + {2048, 2048, 1024}, + }; + + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "Problem " << M << "x" << N << "x" << K << ":\n"; + if(selected) + { + std::cout << " Selected: " << selected->get_name() << "\n"; + } + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + float actual = static_cast(c_host[i]); + if(std::abs(actual - expected) > 0.01f * expected + 1.0f) + ++errors; + } + bool pass = (errors == 0); + std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n"; + if(!pass) + all_passed = false; + print_separator(); + } + + std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/05_json_export.cpp b/dispatcher/examples/gemm/cpp/05_json_export.cpp new file mode 100644 index 0000000000..75ed7308af --- /dev/null +++ b/dispatcher/examples/gemm/cpp/05_json_export.cpp @@ -0,0 +1,127 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 05: JSON Export + * + * Demonstrates exporting registry information to JSON format. + * + * Build: cd dispatcher/build && cmake .. && make gemm_05_json_export + */ + +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple kernels for JSON export demo +// ============================================================================= + +DECL_KERNEL_SET(json_export_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 05: JSON Export", "Export registry information to JSON format"); + args.add_option("--output", "registry.json", "Output JSON file path"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all kernel sets"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 05: JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + if(args.has("--list")) + { + std::cout << "\nDeclared Kernel Sets:\n"; + KernelSetRegistry::instance().print(); + return 0; + } + + std::string output_file = args.get("--output", "registry.json"); + + // ========================================================================= + // Setup Registry + // ========================================================================= + std::cout << "\nSetting up registry...\n"; + Registry registry; + registry.set_name("json_export_registry"); + + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registry: " << registry.get_name() << "\n"; + std::cout << " Kernels: " << registry.size() << "\n"; + + // ========================================================================= + // Export to JSON + // ========================================================================= + std::cout << "\nExporting to JSON...\n"; + + std::string json = registry.export_json(true); + + std::cout << "\nJSON Preview (first 500 chars):\n"; + print_separator(); + std::cout << json.substr(0, std::min(size_t(500), json.size())); + if(json.size() > 500) + std::cout << "\n..."; + std::cout << "\n"; + print_separator(); + + // Write to file + std::ofstream file(output_file); + if(file.is_open()) + { + file << json; + file.close(); + std::cout << "\nExported to: " << output_file << "\n"; + std::cout << "File size: " << json.size() << " bytes\n"; + } + else + { + std::cerr << "Failed to write to: " << output_file << "\n"; + return 1; + } + + // ========================================================================= + // Also show kernel set declarations + // ========================================================================= + std::cout << "\nKernel Set Declarations:\n"; + print_separator(); + KernelSetRegistry::instance().print(); + print_separator(); + + return 0; +} diff --git a/dispatcher/examples/gemm/cpp/06_multi_registry.cpp b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp new file mode 100644 index 0000000000..3077f2d754 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 06: Multiple Registries and Multiple Kernel Sets + * + * Demonstrates: + * - Multiple DECL_KERNEL_SET declarations (each with multiple kernels) + * - Separate Registry instances for different workload types + * - Independent Dispatchers that select from their respective registries + * + * Registration patterns: + * - REGISTER_GENERATED_KERNELS(registry, arch) -> all kernels to one registry + * - REGISTER_KERNEL_SET("set_name", registry, arch) -> specific set by name + * - generated::get_kernel_set_names() -> list available set names + * + * Build: cd dispatcher/build && cmake .. && make gemm_06_multi_registry + * Usage: ./gemm_06_multi_registry [--list] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SETS: Multiple sets with multiple kernels each +// ============================================================================= + +// Compute-bound kernel set: Large tiles for high arithmetic intensity +// Max tile with 32x32 warp is 128x128 (16 warps = 1024 threads) +DECL_KERNEL_SET(compute_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Large tile, max for 32x32 warp + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Same tile, different K for variety + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Memory-bound kernel set: Smaller tiles for better cache efficiency +DECL_KERNEL_SET(memory_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Latency-optimized: Minimal overhead tiles +DECL_KERNEL_SET(latency_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 06: Multiple Registries", + "Separate registries for different workload types"); + args.add_flag("--list", "List all declared kernel sets"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 06: Multiple Registries & Kernel Sets"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show declared kernel sets (from DECL_KERNEL_SET macros) + // ========================================================================= + std::cout << "\nStep 1: Declared Kernel Sets\n"; + std::cout << "-----------------------------\n"; + KernelSetRegistry::instance().print(); + + if(args.has("--list")) + { + // Print detailed info + for(const auto& name : KernelSetRegistry::instance().names()) + { + const auto& set = KernelSetRegistry::instance().get(name); + std::cout << "\n " << name << ":\n"; + for(const auto& decl : set.declarations()) + { + std::cout << " - " << decl.name() << " (tile=" << decl.algorithm.tile_m_ << "x" + << decl.algorithm.tile_n_ << "x" << decl.algorithm.tile_k_ << ")\n"; + } + } + return 0; + } + + // ========================================================================= + // Step 2: Create registries and demonstrate MERGING + // ========================================================================= + std::cout << "\nStep 2: Create and Merge Registries\n"; + std::cout << "------------------------------------\n"; + + // Create individual registries first + Registry compute_registry; + Registry latency_registry; + Registry memory_registry; + + compute_registry.set_name("compute_bound"); + latency_registry.set_name("latency_optimized"); + memory_registry.set_name("memory_bound"); + + // Register kernels to individual registries using set names (no hardcoding) + REGISTER_KERNEL_SET("compute_bound_set", compute_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", latency_registry, gfx_arch); + REGISTER_KERNEL_SET("memory_bound_set", memory_registry, gfx_arch); + + std::cout << " Individual registries:\n"; + std::cout << " compute_bound: " << compute_registry.size() << " kernel(s)\n"; + std::cout << " latency_optimized: " << latency_registry.size() << " kernel(s)\n"; + std::cout << " memory_bound: " << memory_registry.size() << " kernel(s)\n"; + + // MERGE compute + latency into a combined registry + Registry combined_registry; + combined_registry.set_name("compute_latency_combined"); + + // Register both sets into combined registry + REGISTER_KERNEL_SET("compute_bound_set", combined_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", combined_registry, gfx_arch); + + std::cout << "\n After merging compute + latency:\n"; + std::cout << " combined: " << combined_registry.size() << " kernel(s)\n"; + std::cout << " memory (separate): " << memory_registry.size() << " kernel(s)\n"; + + // ========================================================================= + // Step 3: Create dispatchers - one merged, one separate + // ========================================================================= + std::cout << "\nStep 3: Create Dispatchers\n"; + std::cout << "--------------------------\n"; + + Dispatcher combined_dispatcher(&combined_registry); // compute + latency merged + Dispatcher memory_dispatcher(&memory_registry); // memory separate + + std::cout << " combined_dispatcher: compute + latency kernels (" << combined_registry.size() + << " kernels)\n"; + std::cout << " memory_dispatcher: memory-bound kernels (" << memory_registry.size() + << " kernels)\n"; + + // ========================================================================= + // Step 4: Run with different dispatchers + // ========================================================================= + std::cout << "\nStep 4: Run Workloads\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + struct WorkloadTest + { + const char* name; + Dispatcher* dispatcher; + int M, N, K; + }; + + std::vector tests = { + {"Compute-bound (combined)", &combined_dispatcher, 4096, 4096, 4096}, + {"Memory-bound (separate)", &memory_dispatcher, 1024, 1024, 1024}, + {"Latency-opt (combined)", &combined_dispatcher, 512, 512, 512}, + }; + + bool all_passed = true; + + for(const auto& test : tests) + { + Problem problem(test.M, test.N, test.K); + + // Allocate and initialize + GpuBuffer a_dev(test.M * test.K); + GpuBuffer b_dev(test.K * test.N); + GpuBuffer c_dev(test.M * test.N); + + std::vector a_host(test.M * test.K, DataType(1.0f)); + std::vector b_host(test.K * test.N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // Select kernel and run + auto selected = test.dispatcher->select_kernel(problem); + float time_ms = + test.dispatcher->run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(test.M, test.N, test.K, time_ms); + + std::cout << test.name << " (" << test.M << "x" << test.N << "x" << test.K << "):\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify ALL elements + std::vector c_host(test.M * test.N); + c_dev.copy_to_host(c_host.data()); + const float expected = static_cast(test.K); + + int num_errors = 0; + float max_error = 0.0f; + for(int i = 0; i < test.M * test.N; ++i) + { + float actual = static_cast(c_host[i]); + float error = std::abs(actual - expected); + max_error = std::max(max_error, error); + // Allow 1% relative tolerance for FP16 accumulation + if(error > 0.01f * expected + 1.0f) + ++num_errors; + } + + bool test_passed = (num_errors == 0); + std::cout << " Verify: " << (test.M * test.N) << " elements, errors=" << num_errors + << "\n"; + std::cout << " Status: " << (test_passed ? "PASS" : "FAIL") << "\n\n"; + + if(!test_passed) + all_passed = false; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Multi-Registry Pattern Summary:\n"; + print_separator(); + std::cout << R"( +// 1. Declare multiple kernel sets +DECL_KERNEL_SET(compute_bound_set, .add(...)); +DECL_KERNEL_SET(memory_bound_set, .add(...)); +DECL_KERNEL_SET(latency_set, .add(...)); + +// 2. Create registries and register by set NAME (no hardcoding!) +Registry combined_reg, memory_reg; +REGISTER_KERNEL_SET("compute_bound_set", combined_reg, arch); // Add compute +REGISTER_KERNEL_SET("latency_set", combined_reg, arch); // Merge latency +REGISTER_KERNEL_SET("memory_bound_set", memory_reg, arch); // Separate + +// 3. Create dispatchers from merged/separate registries +Dispatcher combined_disp(&combined_reg); // Has both compute + latency +Dispatcher memory_disp(&memory_reg); // Has only memory-bound + +// 4. Choose dispatcher based on workload +if (problem.is_memory_bound()) + memory_disp.run(...); +else + combined_disp.run(...); // Handles both compute & latency workloads +)"; + print_separator(); + std::cout << "Overall Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md new file mode 100644 index 0000000000..1d81a90a0e --- /dev/null +++ b/dispatcher/examples/gemm/cpp/README.md @@ -0,0 +1,229 @@ +# GEMM C++ Examples + +CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build and Run + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build (kernels generated automatically by CMake) +make -j$(nproc) + +# Run examples +cd examples +./gemm_01_basic +./gemm_03_benchmark_validation +./gemm_04_heuristics +``` + +## Examples + +| Example | Description | Complexity | +|---------|-------------|------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | + +## Example Details + +### 01_basic_gemm.cpp - Basic GEMM +Demonstrates the declarative kernel API with three patterns: + +1. **Autofill Pattern** - Minimal specification, defaults filled automatically +2. **Autocorrect Pattern** - Invalid parameters corrected at build time +3. **Full Specification Pattern** - Complete kernel configuration + +```cpp +DECL_KERNEL_SET(basic_kernels, + // Pattern 1: Autofill - minimal specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm(), // Defaults filled by autofill + "gfx942" + ) + // Pattern 2: Full specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Features:** +- Uses generic `REGISTER_GENERATED_KERNELS` macro +- `print_registered_kernels()` utility for debugging +- Demonstrates autofill messages during build + +### 02_multi_size.cpp - Wildcard Expansion +Demonstrates automatic generation of multiple kernel configurations: + +```cpp +DECL_KERNEL_SET(multi_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(*, *, 32) // Wildcard tile M and N + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Wildcard Values:** +- `*`, `-1`, or `ANY_INT` expand to all valid configurations +- Architecture filter prunes invalid combinations automatically +- Example generates 5 valid kernels after arch filtering (from 7 expansions) + +### 03_benchmark_validation.cpp - Benchmark + Validation +Consolidated example combining performance benchmarking with correctness validation: + +```bash +# Benchmark only +./gemm_03_benchmark_validation --warmup 10 --iterations 100 + +# With CPU validation +./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3 + +# With GPU reference validation (faster for large matrices) +./gemm_03_benchmark_validation --verify 2 +``` + +**Features:** +- Warmup iterations (discarded from timing) +- Benchmark iterations with statistics (min/max/mean/median) +- CPU reference validation using `ck_tile::reference_gemm` +- GPU reference validation using `ck_tile::reference_gemm_gpu` +- Configurable tolerances + +### 04_heuristics.cpp - Heuristic Selection +Demonstrates custom kernel selection based on problem characteristics: + +```cpp +// Problem size analysis +auto heuristic = [](const Problem& p) -> std::optional { + if (p.M() * p.N() < 256 * 256) { + return small_kernel_key; // Memory-bound heuristic + } else { + return large_kernel_key; // Compute-bound heuristic + } +}; + +dispatcher.set_heuristic(heuristic); +``` + +**Features:** +- Problem size analysis (small vs large matrices) +- Compute-bound vs memory-bound selection +- Custom heuristic function registration + +### 05_json_export.cpp - JSON Export +Exports registry information to JSON for external tool integration: + +```cpp +auto json = registry.to_json(); +std::ofstream file("kernels.json"); +file << json; +``` + +**Use Cases:** +- Kernel metadata serialization +- External analysis tools +- Configuration management + +### 06_multi_registry.cpp - Multiple Registries +Demonstrates using multiple registries with named kernel sets: + +```cpp +// Define separate kernel sets +DECL_KERNEL_SET(compute_optimized, ...); +DECL_KERNEL_SET(latency_optimized, ...); + +// Register to specific registries +Registry compute_registry, latency_registry; +REGISTER_KERNEL_SET(compute_optimized, compute_registry); +REGISTER_KERNEL_SET(latency_optimized, latency_registry); + +// Use appropriate registry based on workload +Dispatcher compute_dispatcher(compute_registry); +Dispatcher latency_dispatcher(latency_registry); +``` + +**Features:** +- Named kernel set registration with `REGISTER_KERNEL_SET` macro +- Separate registries for different optimization goals +- Dynamic kernel set selection by name + +## Benchmark Parameters (stream_config) + +CK Tile uses `stream_config` for benchmark control: + +```cpp +ck_tile::stream_config cfg{ + nullptr, // stream_id - HIP stream (nullptr = default) + true, // time_kernel - Enable timing + 1, // log_level - Verbosity (0=quiet, 1=normal) + 5, // cold_niters - Warmup iterations + 20, // nrepeat - Benchmark iterations + true, // is_gpu_timer - Use GPU events vs CPU chrono + false, // flush_cache - Flush L2 cache between iterations + 1 // rotating_count - Rotating buffers for cache simulation +}; +``` + +| Parameter | CLI Option | Default | Description | +|-----------|------------|---------|-------------| +| `cold_niters_` | `--warmup` | 5 | Warmup iterations | +| `nrepeat_` | `--iterations` | 100 | Benchmark iterations | +| `flush_cache_` | - | false | Flush L2 cache | +| `rotating_count_` | - | 1 | Rotating buffers | +| `is_gpu_timer_` | - | true | GPU timer vs CPU | + +## Declarative Kernel Pattern + +All examples use the declarative `DECL_KERNEL_SET` macro: + +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature() // WHAT: operation signature + .dtype("fp16") // Data type + .layout("rcr"), // Matrix layouts (A=row, B=col, C=row) + Algorithm() // HOW: implementation details + .tile(256, 256, 32) // Tile sizes (M, N, K) + .wave(2, 2, 1) // Wave configuration + .warp(32, 32, 16) // Warp tile sizes + .pipeline("compv4") // Pipeline type + .scheduler("intrawave"), // Scheduler type + "gfx942" // WHERE: target architecture + ) +); +``` + +**Key Macros:** +- `DECL_KERNEL_SET(name, ...)` - Declare a kernel set +- `REGISTER_GENERATED_KERNELS` - Register all kernels from this example +- `REGISTER_KERNEL_SET(name, registry)` - Register specific kernel set to a registry + +## Related Documentation + +- [Python GEMM Examples](../python/README.md) +- [Convolution Examples](../../conv/cpp/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py new file mode 100644 index 0000000000..93a78d24d1 --- /dev/null +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic GEMM with Multiple Kernels + +Demonstrates: +1. Declaring multiple kernel configurations +2. Printing all registered kernels +3. Running each kernel and validating output +4. Comparing performance across kernels + +Complexity: ★★☆☆☆ + +Usage: + python3 01_basic_gemm.py + python3 01_basic_gemm.py --help + python3 01_basic_gemm.py --dtype bf16 + python3 01_basic_gemm.py --size 2048 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +@dataclass +class KernelSpec: + """Specification for a kernel configuration""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + + +# Define multiple kernel configurations to test (50+ kernels) +KERNEL_SPECS = [ + # Small tiles - compv3 + KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), + KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), + # Small tiles - compv4 + KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), + KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), + # Medium tiles - compv3 + KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), + KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), + KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), + # Medium tiles - compv4 + KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), + KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), + KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), + # Rectangular tiles - compv3 + KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), + KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), + KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), + KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), + # Rectangular tiles - compv4 + KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), + KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), + KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), + KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), + # Large tiles - compv3 + KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), + KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), + KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), + KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), + KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), + KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), + # Large tiles - compv4 + KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), + KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), + KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), + KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), + KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), + KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), + # Interwave scheduler variants + KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), + KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), + # More tile_k variations - compv3 + KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), + KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), + KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), + # More tile_k variations - compv4 + KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), + KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), + KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), + # Additional rectangular + KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), + KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), + KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), + KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), + # Additional compv4 variants + KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"), + KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), + KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), + KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create a KernelConfig from a spec""" + # Adjust warp tiles based on tile size + if spec.tile_m <= 64: + warp_m, warp_n = 16, 16 + else: + warp_m, warp_n = 32, 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +def print_kernel_table(specs: List[KernelSpec], dtype: str): + """Print a formatted table of kernel configurations""" + print("\n" + "=" * 70) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 70) + print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 68) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" + ) + + print(" " + "-" * 68) + print(f" Data type: {dtype}") + + +def main(): + parser = argparse.ArgumentParser( + description="Basic GEMM Example with Multiple Kernels", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 01_basic_gemm.py # Default FP16 with 4 kernels + python3 01_basic_gemm.py --dtype bf16 # BF16 mode + python3 01_basic_gemm.py --size 2048 # Larger problem size + python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 01: Basic GEMM with Multiple Kernels") + print("=" * 70) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # ========================================================================= + # Step 1: Print all kernel configurations + # ========================================================================= + print_kernel_table(specs, args.dtype) + + # ========================================================================= + # Step 2: Setup and test each kernel + # ========================================================================= + print("\n" + "=" * 70) + print(" RUNNING KERNELS") + print("=" * 70) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + M, N, K = args.size, args.size, args.size + + results = [] + + print(f"\n Problem size: {M}x{N}x{K}\n") + print( + f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + # Create unique test data per kernel + np.random.seed(42 + i * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Create config and setup dispatcher + config = create_kernel_config(spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"kernel_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + if not setup.success: + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + # Check if size is supported + if not dispatcher.is_supported(M, N, K): + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Validate against NumPy reference + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + + # Check if within tolerance + passed = max_err < 1e-2 + status = "PASS" if passed else "FAIL" + + print( + f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + ) + results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) + + cleanup_gemm() + + # ========================================================================= + # Step 3: Summary + # ========================================================================= + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + + passed = sum(1 for r in results if r[1]) + failed = len(results) - passed + + print(f"\n Results: {passed}/{len(results)} kernels passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + + if results: + valid_results = [r for r in results if r[1]] + if valid_results: + best = max(valid_results, key=lambda x: x[3]) + print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") + + if failed == 0: + print("\n *** ALL KERNELS PASSED ***") + else: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 70) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py new file mode 100644 index 0000000000..039aba2790 --- /dev/null +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Batch GEMM + +Runs multiple GEMM operations with different sizes. + +Complexity: ★★☆☆☆ + +Usage: + python3 02_batch_gemm.py + python3 02_batch_gemm.py --help + python3 02_batch_gemm.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch GEMM Example - runs multiple sizes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 02_batch_gemm.py # Default FP16 + python3 02_batch_gemm.py --dtype bf16 # BF16 GEMM + python3 02_batch_gemm.py --max-size 2048 # Limit max size + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--max-size", + type=int, + default=4096, + help="Maximum problem size (default: 4096)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 02: Batch GEMM") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run batch of different sizes + # ========================================================================= + print("\nStep 2: Run Batch") + + # Generate sizes up to max_size + all_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}") + print(" " + "-" * 60) + + total_ops = 0 + total_time = 0 + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped") + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + total_ops += 2 * M * N * K + total_time += result.time_ms + print( + f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK" + ) + else: + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error") + + print(" " + "-" * 60) + + if total_time > 0: + avg_tflops = (total_ops / 1e12) / (total_time / 1000) + print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("Batch GEMM complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py new file mode 100644 index 0000000000..bec1b7e2fb --- /dev/null +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Benchmark + +Performance benchmarking with compute-optimized kernel configuration. + +Complexity: ★★★☆☆ + +Usage: + python3 03_benchmark.py + python3 03_benchmark.py --help + python3 03_benchmark.py --size 4096 + python3 03_benchmark.py --dtype bf16 --iterations 20 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmark Example - performance testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 03_benchmark.py # Default benchmark suite + python3 03_benchmark.py --size 4096 # Single size benchmark + python3 03_benchmark.py --dtype bf16 # BF16 benchmark + python3 03_benchmark.py --iterations 20 # More iterations + """, + ) + parser.add_argument( + "--dtype", + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: bf16)", + ) + parser.add_argument( + "--size", + type=int, + default=0, + help="Single problem size MxNxK (default: run all sizes)", + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Warmup iterations (default: 3)" + ) + parser.add_argument( + "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 03: Benchmark") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher with compute-optimized config + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + pipeline="compv4", + scheduler="intrawave", + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Benchmark + # ========================================================================= + print("\nStep 2: Benchmark") + + if args.size > 0: + sizes = [(args.size, args.size, args.size)] + else: + sizes = [ + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (1024, 2048, 512), + (2048, 1024, 2048), + ] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n") + + print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}") + print(" " + "-" * 60) + + all_tflops = [] + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + # Warmup + for _ in range(args.warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(args.iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + if times: + min_time = min(times) + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + all_tflops.append(tflops) + print( + f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + + if all_tflops: + print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS") + print(f" Peak: {max(all_tflops):.2f} TFLOPS") + + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py new file mode 100644 index 0000000000..2fe54c53f7 --- /dev/null +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Validation + +Validates GPU GEMM against NumPy reference. + +Complexity: ★★★☆☆ + +Usage: + python3 04_validation.py + python3 04_validation.py --help + python3 04_validation.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Validator, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Validation Example - validates GPU results against NumPy", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 04_validation.py # Default FP16 validation + python3 04_validation.py --dtype bf16 # BF16 validation + python3 04_validation.py --rtol 1e-2 # Relaxed tolerance + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)" + ) + parser.add_argument( + "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 04: Validation") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run validation tests + # ========================================================================= + print("\nStep 2: Validation Tests") + + validator = Validator(rtol=args.rtol, atol=args.atol) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_cases = [ + ("Identity", 128, 128, 128, "identity"), + ("Small", 256, 256, 256, "random"), + ("Medium", 512, 512, 512, "random"), + ("Large", 1024, 1024, 1024, "random"), + ("Non-square", 512, 1024, 256, "random"), + ] + + passed = 0 + failed = 0 + + print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}") + print(" " + "-" * 55) + + for name, M, N, K, pattern in test_cases: + if not dispatcher.is_supported(M, N, K): + print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped") + continue + + np.random.seed(42) + if pattern == "identity": + A = np.eye(M, K, dtype=np_dtype) + B = np.eye(K, N, dtype=np_dtype) + else: + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + if not result.success: + print(f" {name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED") + failed += 1 + continue + + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + if is_valid: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED") + passed += 1 + else: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") + failed += 1 + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + total = passed + failed + print(f"Results: {passed}/{total} passed") + print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}") + print("=" * 60) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py new file mode 100644 index 0000000000..493ce46d22 --- /dev/null +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: NumPy Integration + +Shows how to create a GPU-accelerated matmul wrapper. + +Complexity: ★★☆☆☆ + +Usage: + python3 05_numpy_integration.py + python3 05_numpy_integration.py --help + python3 05_numpy_integration.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +class GPUMatmul: + """GPU-accelerated matrix multiplication wrapper.""" + + def __init__(self, dispatcher: Dispatcher): + self.dispatcher = dispatcher + + def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute C = A @ B on GPU with CPU fallback.""" + M, K = A.shape + K2, N = B.shape + + if K != K2: + raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}") + + if not self.dispatcher.is_supported(M, N, K): + return np.matmul(A, B) + + result = self.dispatcher.run(A, B, M, N, K) + return result.output if result.success else np.matmul(A, B) + + +def main(): + parser = argparse.ArgumentParser( + description="NumPy Integration Example - GPU-accelerated matmul wrapper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 05_numpy_integration.py # Default FP16 + python3 05_numpy_integration.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 05: NumPy Integration") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Create GPU matmul wrapper + # ========================================================================= + print("\nStep 2: Create GPUMatmul") + + gpu_matmul = GPUMatmul(dispatcher=dispatcher) + print(" gpu_matmul ready") + + # ========================================================================= + # Step 3: Demo - Simple multiplication using gpu_matmul + # ========================================================================= + print("\nStep 3: Demo - Simple Multiplication") + + A = np.random.randn(1024, 512).astype(np_dtype) * 0.1 + B = np.random.randn(512, 256).astype(np_dtype) * 0.1 + + # Use the gpu_matmul wrapper + C = gpu_matmul(A, B) + print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}") + + M, K = A.shape + _, N = B.shape + result = dispatcher.run(A, B, M, N, K) + + print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}") + print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS") + + # ========================================================================= + # Step 4: Demo - FFN block + # ========================================================================= + print("\nStep 4: Demo - FFN Block") + + batch, hidden, ffn = 128, 768, 3072 + X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02 + W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02 + W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02 + + result1 = dispatcher.run(X, W1, batch, ffn, hidden) + H = result1.output + result2 = dispatcher.run(H, W2, batch, hidden, ffn) + + print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}") + print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms") + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("NumPy Integration Pattern:") + print("=" * 60) + print(" 1. setup_gemm_dispatcher(config)") + print(" 2. GPUMatmul(dispatcher)") + print(" 3. C = gpu_matmul(A, B)") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py new file mode 100644 index 0000000000..9e062e507b --- /dev/null +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: JSON Export + +Exports registry configuration to JSON. + +Complexity: ★★☆☆☆ + +Usage: + python3 06_json_export.py + python3 06_json_export.py --help + python3 06_json_export.py --output my_kernels.json +""" + +import sys +import json +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Export Example - exports registry to JSON", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 06_json_export.py # Default output to kernels.json + python3 06_json_export.py --output my.json # Custom output file + """, + ) + parser.add_argument( + "--output", + "-o", + default="kernels.json", + help="Output JSON file (default: kernels.json)", + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 06: JSON Export") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + # ========================================================================= + # Step 2: Define additional configs for export + # ========================================================================= + print("\nStep 2: Define Additional Configs") + + configs = [ + config, + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + gfx_arch=args.arch, + ), + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + gfx_arch=args.arch, + ), + ] + + for cfg in configs: + print(f" - {cfg.tile_str}") + + # ========================================================================= + # Step 3: Export to JSON + # ========================================================================= + print("\nStep 3: Export to JSON") + + export_data = { + "registry": setup.registry.name, + "kernel_count": len(configs), + "kernels": [], + } + + for cfg in configs: + kernel_info = { + "tile": cfg.tile_str, + "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c}, + "layout": cfg.layout, + "pipeline": cfg.pipeline, + "target": cfg.gfx_arch, + } + export_data["kernels"].append(kernel_info) + + # Include C++ library info + if setup.lib: + cpp_json = setup.lib.export_registry_json() + try: + export_data["cpp_registry"] = json.loads(cpp_json) + except json.JSONDecodeError: + pass + + json_str = json.dumps(export_data, indent=2) + + with open(args.output, "w") as f: + f.write(json_str) + print(f" Saved to: {args.output}") + + # Preview + print("\nStep 4: Preview") + print("-" * 60) + print(json_str[:500] + ("..." if len(json_str) > 500 else "")) + print("-" * 60) + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("JSON Export complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py new file mode 100644 index 0000000000..8160030631 --- /dev/null +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -0,0 +1,513 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 07: Stress Test - Multiple Kernels with Validation + +Consolidated stress test that: +1. Declares multiple kernel configurations (various tiles, pipelines, layouts) +2. Prints all registered kernels with details +3. Validates each kernel against NumPy reference +4. Optional benchmarking mode + +This tests: +- Multiple tile sizes (64x64, 128x128, 256x256) +- Multiple pipelines (compv3, compv4) +- Multiple data types (fp16, bf16) +- Different schedulers (intrawave, interwave) + +Complexity: ★★★★☆ + +Usage: + python3 07_stress_test.py + python3 07_stress_test.py --help + python3 07_stress_test.py --num-kernels 10 + python3 07_stress_test.py --benchmark + python3 07_stress_test.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + Validator, +) + + +@dataclass +class KernelSpec: + """A kernel specification for testing""" + + name: str + tile_m: int + tile_n: int + tile_k: int + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + pipeline: str = "compv3" + scheduler: str = "intrawave" + layout: str = "rcr" + + def to_config(self, dtype: str, arch: str) -> KernelConfig: + """Convert to KernelConfig""" + # Adjust warp tiles for smaller tiles + warp_m = min(self.warp_m, self.tile_m // self.wave_m) + warp_n = min(self.warp_n, self.tile_n // self.wave_n) + warp_k = self.warp_k + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a={"r": "row", "c": "col"}[self.layout[0]], + layout_b={"r": "row", "c": "col"}[self.layout[1]], + layout_c={"r": "row", "c": "col"}[self.layout[2]], + tile_m=self.tile_m, + tile_n=self.tile_n, + tile_k=self.tile_k, + wave_m=self.wave_m, + wave_n=self.wave_n, + wave_k=self.wave_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + pipeline=self.pipeline, + scheduler=self.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# Define stress test kernel configurations +KERNEL_SPECS = [ + # Small tiles - compv3 + KernelSpec( + "small_compv3", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv3", + ), + KernelSpec( + "small_compv4", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv4", + ), + # Medium tiles + KernelSpec( + "medium_compv3", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "medium_compv4", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), + KernelSpec( + "medium_k64", + 128, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Rectangular tiles + KernelSpec( + "rect_64x128", + 64, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "rect_128x64", + 128, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Different schedulers + KernelSpec( + "interwave", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + scheduler="interwave", + ), + # Large tiles + KernelSpec( + "large_compv3", + 256, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "large_compv4", + 256, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), +] + + +def print_kernel_summary(specs: List[KernelSpec], dtype: str): + """Print a summary table of all kernel specs""" + print("\n" + "=" * 80) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 80) + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + wave = f"{spec.wave_m}x{spec.wave_n}x{spec.wave_k}" + warp = f"{spec.warp_m}x{spec.warp_n}x{spec.warp_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<12} {wave:<10} {warp:<12} {spec.pipeline:<10} {spec.scheduler:<10}" + ) + + print(" " + "-" * 78) + print(f" Data type: {dtype}\n") + + +def validate_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + validator: Validator, + kernel_index: int = 0, + verbose: bool = False, +) -> Tuple[bool, float, str]: + """ + Validate a single kernel configuration. + Returns: (passed, max_error, message) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + # Create config + config = spec.to_config(dtype, arch) + + # Setup dispatcher + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"stress_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, f"Setup failed: {setup.error}" + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, f"Size {M}x{N}x{K} not supported" + + # Use different seed per kernel to get unique test data + # This ensures each kernel is tested with different matrices + np.random.seed(42 + kernel_index * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Run GPU GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + cleanup_gemm() + return False, 0.0, "GPU execution failed" + + # Validate against NumPy + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + cleanup_gemm() + + return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS" + + +def benchmark_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + warmup: int = 3, + iterations: int = 10, +) -> Tuple[bool, float, float]: + """ + Benchmark a kernel configuration. + Returns: (success, avg_time_ms, tflops) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + config = spec.to_config(dtype, arch) + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"bench_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, 0.0 + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, 0.0 + + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Warmup + for _ in range(warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + cleanup_gemm() + + if not times: + return False, 0.0, 0.0 + + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + + return True, avg_time, tflops + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Stress Test - Multiple kernels with validation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 07_stress_test.py # Test all kernels + python3 07_stress_test.py --num-kernels 5 # Test first 5 kernels + python3 07_stress_test.py --benchmark # Include benchmarks + python3 07_stress_test.py --dtype bf16 # Test BF16 + python3 07_stress_test.py --size 2048 # Use 2048x2048 matrices + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Include benchmark timing", + ) + parser.add_argument( + "--rtol", + type=float, + default=1e-2, + help="Relative tolerance (default: 1e-2)", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-2, + help="Absolute tolerance (default: 1e-2)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 80) + print("Example 07: GEMM Stress Test - Multiple Kernels") + print("=" * 80) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # Print kernel summary + print_kernel_summary(specs, args.dtype) + + # Run validation + print("\n" + "=" * 80) + print(" VALIDATION RESULTS") + print("=" * 80) + + validator = Validator(rtol=args.rtol, atol=args.atol) + + if args.benchmark: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}" + ) + else: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}" + ) + print(" " + "-" * 78) + + passed = 0 + failed = 0 + skipped = 0 + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + try: + is_valid, max_err, info = validate_kernel( + spec, args.dtype, args.arch, args.size, validator, kernel_index=i + ) + + if is_valid: + status = "PASS" + passed += 1 + else: + status = "FAIL" + failed += 1 + + if args.benchmark: + success, avg_time, tflops = benchmark_kernel( + spec, args.dtype, args.arch, args.size + ) + if success: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}" + ) + + except Exception as e: + skipped += 1 + print( + f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}" + ) + + # Summary + print("\n" + "=" * 80) + print(" SUMMARY") + print("=" * 80) + total = passed + failed + skipped + print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped") + print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}") + print(f" Tolerance: rtol={args.rtol}, atol={args.atol}") + print(f" Architecture: {args.arch}") + + if failed == 0 and skipped == 0: + print("\n *** ALL KERNELS PASSED ***") + elif failed > 0: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 80) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py new file mode 100644 index 0000000000..e2763c0513 --- /dev/null +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 08: Custom Heuristics + +Demonstrates custom kernel selection heuristics based on problem characteristics. + +This example shows how to: +1. Define multiple kernel configurations for different workloads +2. Implement custom heuristics to select the best kernel +3. Test heuristic selection across different problem sizes + +Heuristic strategies: +- Size-based: Small tiles for small problems, large tiles for large problems +- Compute-bound: Maximize compute utilization for large matrices +- Memory-bound: Optimize memory access for bandwidth-limited cases +- Latency-focused: Minimize kernel launch overhead for small problems + +Complexity: ★★★★☆ + +Usage: + python3 08_heuristics.py + python3 08_heuristics.py --help + python3 08_heuristics.py --strategy compute + python3 08_heuristics.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List +from enum import Enum + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +# ============================================================================= +# Kernel Specifications +# ============================================================================= + + +@dataclass +class KernelSpec: + """Kernel specification with metadata for heuristic selection""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + # Metadata for heuristics + category: str = "balanced" # small, balanced, large, compute, memory + min_problem_size: int = 0 + max_problem_size: int = float("inf") + + +# Define kernel pool for heuristic selection (20+ kernels) +KERNEL_POOL = [ + # ========================================================================== + # SMALL TILES - Low latency, good for small problems + # ========================================================================== + KernelSpec( + "small_64x64_k32", + 64, + 64, + 32, + "compv3", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_k64", + 64, + 64, + 64, + "compv3", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_v4", + 64, + 64, + 32, + "compv4", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + # ========================================================================== + # MEDIUM TILES - Balanced performance + # ========================================================================== + KernelSpec( + "medium_128x128_k32", + 128, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + max_problem_size=2048 * 2048, + ), + KernelSpec( + "medium_128x128_k64", + 128, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_k128", + 128, + 128, + 128, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # Rectangular medium tiles + KernelSpec( + "rect_64x128_k32", + 64, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_128x64_k32", + 128, + 64, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_64x128_k64", + 64, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "rect_128x64_k64", + 128, + 64, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # ========================================================================== + # LARGE TILES - High throughput for large problems + # ========================================================================== + KernelSpec( + "large_256x128_k32", + 256, + 128, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x128_k64", + 256, + 128, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k32", + 128, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k64", + 128, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x256_k32", + 256, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + KernelSpec( + "large_256x256_k64", + 256, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # COMPUTE-OPTIMIZED - compv4 pipeline for compute-bound workloads + # ========================================================================== + KernelSpec( + "compute_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_256x128_v4", + 256, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=512 * 512, + ), + KernelSpec( + "compute_256x256_v4", + 256, + 256, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # MEMORY-OPTIMIZED - Good cache utilization for memory-bound workloads + # ========================================================================== + KernelSpec( + "memory_128x128_k16", + 128, + 128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=256 * 256, + ), + KernelSpec( + "memory_64x128_k16", + 64, + 128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=128 * 128, + ), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create KernelConfig from spec""" + warp_m = 16 if spec.tile_m <= 64 else 32 + warp_n = 16 if spec.tile_n <= 64 else 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# ============================================================================= +# Heuristic Strategies +# ============================================================================= + + +class HeuristicStrategy(Enum): + SIZE_BASED = "size" + COMPUTE_BOUND = "compute" + MEMORY_BOUND = "memory" + LATENCY_FOCUSED = "latency" + + +def size_based_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel based on problem size. + - Small problems: Use small tiles for low latency + - Medium problems: Use balanced tiles + - Large problems: Use large tiles for high throughput + + Also considers K dimension for tile_k selection. + """ + total_elements = M * N + + # Filter by problem size constraints + candidates = [ + k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size + ] + + if not candidates: + candidates = kernels # Fall back to all kernels + + # Determine target category based on problem size + if total_elements < 256 * 256: + target_category = "small" + elif total_elements < 1024 * 1024: + target_category = "balanced" + else: + target_category = "large" + + # Filter by category if possible + category_candidates = [k for k in candidates if k.category == target_category] + if category_candidates: + candidates = category_candidates + + # Select best tile_k based on K dimension + # Prefer tile_k that divides K well + def tile_k_score(k): + if K % k.tile_k == 0: + return 0 # Perfect division + return K % k.tile_k # Remainder (lower is better) + + # Sort by tile_k fit, then by tile size + candidates.sort(key=lambda k: (tile_k_score(k), -k.tile_m * k.tile_n)) + + return candidates[0] + + +def compute_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for compute-bound workloads. + Prefers compv4 pipeline and larger tiles. + Selects based on problem size to maximize compute utilization. + """ + total_elements = M * N + + # Prefer compute category kernels + compute_kernels = [k for k in kernels if k.category == "compute"] + + if not compute_kernels: + # Fall back to compv4 kernels + compute_kernels = [k for k in kernels if k.pipeline == "compv4"] + + if not compute_kernels: + compute_kernels = kernels + + # Filter by problem size + valid = [k for k in compute_kernels if k.min_problem_size <= total_elements] + if valid: + compute_kernels = valid + + # For large problems, prefer larger tiles + if total_elements >= 1024 * 1024: + return max(compute_kernels, key=lambda k: k.tile_m * k.tile_n * k.tile_k) + else: + # For smaller problems, prefer medium tiles + return min( + compute_kernels, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128) + ) + + +def memory_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for memory-bound workloads. + Prefers smaller tile_k for better memory access patterns. + """ + # Prefer memory category kernels first + memory_kernels = [k for k in kernels if k.category == "memory"] + if memory_kernels: + # Select based on problem size + total = M * N + if total < 512 * 512: + return min(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + return max(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to balanced with smaller tile_k + balanced = [k for k in kernels if k.category == "balanced"] + if balanced: + # Prefer smaller tile_k for memory-bound + return min(balanced, key=lambda k: k.tile_k) + + # Fall back to medium-sized tile with small tile_k + return min( + kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128)) + ) + + +def latency_focused_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for low latency. + Prefers smaller tiles and compv4 for faster execution. + """ + # Prefer small category + small_kernels = [k for k in kernels if k.category == "small"] + + if small_kernels: + # Among small kernels, prefer compv4 for lower latency + v4_small = [k for k in small_kernels if k.pipeline == "compv4"] + if v4_small: + return v4_small[0] + return small_kernels[0] + + # Fall back to smallest tile with compv4 if available + all_v4 = [k for k in kernels if k.pipeline == "compv4"] + if all_v4: + return min(all_v4, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to smallest tile + return min(kernels, key=lambda k: k.tile_m * k.tile_n) + + +HEURISTICS = { + HeuristicStrategy.SIZE_BASED: size_based_heuristic, + HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic, + HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic, + HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic, +} + + +# ============================================================================= +# Main +# ============================================================================= + + +def print_kernel_pool(kernels: List[KernelSpec]): + """Print available kernels""" + print("\n" + "=" * 75) + print(" KERNEL POOL") + print("=" * 75) + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}") + print(" " + "-" * 73) + + for i, k in enumerate(kernels, 1): + tile = f"{k.tile_m}x{k.tile_n}x{k.tile_k}" + print(f" {i:<3} {k.name:<22} {tile:<14} {k.pipeline:<10} {k.category:<12}") + + print(" " + "-" * 73) + + +def main(): + parser = argparse.ArgumentParser( + description="Custom Heuristics Example - intelligent kernel selection", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 08_heuristics.py # Default size-based heuristic + python3 08_heuristics.py --strategy compute # Compute-bound heuristic + python3 08_heuristics.py --strategy memory # Memory-bound heuristic + python3 08_heuristics.py --strategy latency # Latency-focused heuristic + python3 08_heuristics.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--strategy", + default="size", + choices=["size", "compute", "memory", "latency"], + help="Heuristic strategy (default: size)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 75) + print("Example 08: Custom Heuristics") + print("=" * 75) + + # Map strategy string to enum + strategy_map = { + "size": HeuristicStrategy.SIZE_BASED, + "compute": HeuristicStrategy.COMPUTE_BOUND, + "memory": HeuristicStrategy.MEMORY_BOUND, + "latency": HeuristicStrategy.LATENCY_FOCUSED, + } + strategy = strategy_map[args.strategy] + heuristic_fn = HEURISTICS[strategy] + + print(f"\n Strategy: {strategy.value}") + print(f" Data type: {args.dtype}") + + # Print kernel pool + print_kernel_pool(KERNEL_POOL) + + # ========================================================================= + # Test heuristic selection across different problem sizes + # ========================================================================= + print("\n" + "=" * 75) + print(" HEURISTIC SELECTION TEST") + print("=" * 75) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_sizes = [ + (128, 128, 64), # Small + (256, 256, 128), # Small-medium + (512, 512, 256), # Medium + (1024, 1024, 512), # Medium-large + (2048, 2048, 1024), # Large + ] + + print( + f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + results = [] + + for M, N, K in test_sizes: + # Use heuristic to select kernel + selected_spec = heuristic_fn(M, N, K, KERNEL_POOL) + + # Create config and setup + config = create_kernel_config(selected_spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"heuristic_{selected_spec.name}", + verbose=False, + auto_rebuild=True, + ) + + size_str = f"{M}x{N}x{K}" + + if not setup.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + if not dispatcher.is_supported(M, N, K): + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Validate + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + passed = max_err < 1e-2 + + status = "PASS" if passed else "FAIL" + print( + f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}" + ) + results.append( + (size_str, selected_spec.name, passed, result.time_ms, result.tflops) + ) + + cleanup_gemm() + + # ========================================================================= + # Summary + # ========================================================================= + print("\n" + "=" * 75) + print(" SUMMARY") + print("=" * 75) + + passed = sum(1 for r in results if r[2]) + failed = len(results) - passed + + print(f"\n Strategy: {strategy.value}") + print(f" Results: {passed}/{len(results)} tests passed") + + # Show kernel selection distribution + kernel_usage = {} + for r in results: + kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1 + + print("\n Kernel Selection Distribution:") + for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]): + print(f" {kernel}: {count} times") + + if results: + valid_results = [r for r in results if r[2]] + if valid_results: + avg_tflops = sum(r[4] for r in valid_results) / len(valid_results) + print(f"\n Average TFLOPS: {avg_tflops:.2f}") + + if failed == 0: + print("\n *** ALL TESTS PASSED ***") + else: + print(f"\n *** {failed} TESTS FAILED ***") + + print("=" * 75) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py new file mode 100644 index 0000000000..97cbce3497 --- /dev/null +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 09: Multiple Registries + +Demonstrates multiple registries for different optimization targets. + +Complexity: ★★★★★ + +Usage: + python3 09_multi_registry.py + python3 09_multi_registry.py --help + python3 09_multi_registry.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Registry, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Multiple Registries Example - optimization-specific registries", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 09_multi_registry.py # Default FP16 + python3 09_multi_registry.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 09: Multiple Registries") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup base dispatcher + # ========================================================================= + print("\nStep 1: Setup Base Dispatcher") + + base_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + lib = setup.lib + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Define configs for different optimization targets + # ========================================================================= + print("\nStep 2: Define Optimization Targets") + + compute_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + wave_m=4, + wave_n=4, + pipeline="compv4", + gfx_arch=args.arch, + ) + memory_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + pipeline="compv4", + gfx_arch=args.arch, + ) + latency_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + wave_m=1, + wave_n=1, + pipeline="compv3", + gfx_arch=args.arch, + ) + + print(f" Compute: {compute_config.tile_str} (large matrices)") + print(f" Memory: {memory_config.tile_str} (medium matrices)") + print(f" Latency: {latency_config.tile_str} (small matrices)") + + # ========================================================================= + # Step 3: Create registries + # ========================================================================= + print("\nStep 3: Create Registries") + + compute_registry = Registry(name="compute", lib=lib) + compute_registry.register_kernel(compute_config) + + memory_registry = Registry(name="memory", lib=lib) + memory_registry.register_kernel(memory_config) + + latency_registry = Registry(name="latency", lib=lib) + latency_registry.register_kernel(latency_config) + + # ========================================================================= + # Step 4: Create dispatchers + # ========================================================================= + print("\nStep 4: Create Dispatchers") + + compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib) + memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib) + latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib) + + print(f" {compute_dispatcher}") + print(f" {memory_dispatcher}") + print(f" {latency_dispatcher}") + + # ========================================================================= + # Step 5: Smart dispatcher selection + # ========================================================================= + print("\nStep 5: Smart Dispatcher Selection") + + def select_dispatcher(M: int, N: int, K: int) -> Dispatcher: + elements = M * N + if elements >= 4096 * 4096: + return compute_dispatcher + elif elements >= 1024 * 1024: + return memory_dispatcher + else: + return latency_dispatcher + + test_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + + print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}") + print(" " + "-" * 55) + + for M, N, K in test_sizes: + dispatcher = select_dispatcher(M, N, K) + + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + print( + f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} " + f"{result.time_ms:>12.4f} {result.tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Multi-Registry Pattern:") + print("=" * 60) + print(" 1. Define KernelConfig for each optimization target") + print(" 2. Create Registry for each target") + print(" 3. Register configs to appropriate registries") + print(" 4. Create Dispatcher for each registry") + print(" 5. Select dispatcher based on problem characteristics") + print(" 6. Run GEMM with selected dispatcher") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py new file mode 100644 index 0000000000..e16e4e271f --- /dev/null +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 10: Advanced Benchmarking with Full Control + +This example demonstrates all available benchmark parameters: + - warmup: Number of warmup iterations (default: 5) + - repeat: Number of benchmark iterations (default: 20) + - flush_cache: Flush GPU cache between iterations (default: False) + - timer: Timer type - "gpu" (default) or "cpu" + - init: Initialization method - "random", "linear", "constant" + +Usage: + python3 10_advanced_benchmark.py + python3 10_advanced_benchmark.py --warmup 10 --repeat 100 + python3 10_advanced_benchmark.py --init linear +""" + +import argparse +import sys +from pathlib import Path + +# Add paths for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Advanced GEMM benchmarking with full parameter control" + ) + + # Problem size + parser.add_argument("-m", type=int, default=2048, help="M dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=2048, help="K dimension") + + # Benchmark parameters + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--repeat", type=int, default=20, help="Number of benchmark iterations" + ) + parser.add_argument( + "--flush-cache", action="store_true", help="Flush GPU cache between iterations" + ) + parser.add_argument( + "--timer", choices=["gpu", "cpu"], default="gpu", help="Timer type (gpu or cpu)" + ) + parser.add_argument( + "--init", + choices=["random", "linear", "constant"], + default="random", + help="Initialization method", + ) + + # Kernel configuration + parser.add_argument("--dtype", default="fp16", help="Data type") + parser.add_argument("--pipeline", default="compv4", help="Pipeline type") + parser.add_argument("--arch", default="gfx942", help="GPU architecture") + + return parser.parse_args() + + +def initialize_matrix(shape, method, dtype): + """Initialize matrix with specified method""" + if method == "random": + return np.random.randn(*shape).astype(dtype) * 0.5 + elif method == "linear": + total = np.prod(shape) + return np.arange(total).reshape(shape).astype(dtype) / total + elif method == "constant": + return np.ones(shape, dtype=dtype) + else: + return np.random.randn(*shape).astype(dtype) + + +def main(): + args = parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 10: Advanced GEMM Benchmarking") + print("=" * 70) + + # Show benchmark configuration + print("\nBenchmark Configuration:") + print(f" Problem Size: {args.m} x {args.n} x {args.k}") + print(f" Warmup: {args.warmup} iterations") + print(f" Repeat: {args.repeat} iterations") + print(f" Flush Cache: {args.flush_cache}") + print(f" Timer: {args.timer}") + print(f" Init Method: {args.init}") + print(f" Data Type: {args.dtype}") + print(f" Pipeline: {args.pipeline}") + print(f" Architecture: {args.arch}") + print() + + # Map dtype + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # Initialize matrices + print("Step 1: Initialize matrices...") + A = initialize_matrix((args.m, args.k), args.init, np_dtype) + B = initialize_matrix((args.k, args.n), args.init, np_dtype) + print(f" A: {A.shape} ({args.init})") + print(f" B: {B.shape} ({args.init})") + + # Create kernel config (does not include M/N/K - those are problem size) + print("\nStep 2: Create kernel configuration...") + kernel_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", # B is column-major for optimal performance + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline=args.pipeline, + scheduler="intrawave", + epilogue="cshuffle", + gfx_arch=args.arch, + ) + print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}") + + # Setup dispatcher + print("\nStep 3: Setup dispatcher...") + setup = setup_gemm_dispatcher( + config=kernel_config, + registry_name="benchmark_gemm", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + print(f" Library: {setup.lib.path if setup.lib else 'N/A'}") + print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}") + + # Run benchmark with multiple iterations + print("\nStep 4: Run benchmark...") + print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...") + + # Warmup + for _ in range(args.warmup): + _ = dispatcher.run(A, B, args.m, args.n, args.k) + + # Benchmark + times = [] + for _ in range(args.repeat): + result = dispatcher.run(A, B, args.m, args.n, args.k) + if result.success: + times.append(result.time_ms) + + if times: + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + + # Calculate TFLOPS + flops = 2 * args.m * args.n * args.k + avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0 + max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0 + + # Calculate bandwidth (C has same dtype as A and B) + C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize + bandwidth_gb = ( + (A.nbytes + B.nbytes + C_bytes) / 1e9 / (avg_time / 1000) + if avg_time > 0 + else 0 + ) + + print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***") + print(f" Average Time: {avg_time:.4f} ms") + print(f" Min Time: {min_time:.4f} ms") + print(f" Max Time: {max_time:.4f} ms") + print(f" Avg TFLOPS: {avg_tflops:.2f}") + print(f" Peak TFLOPS: {max_tflops:.2f}") + print(f" Bandwidth: {bandwidth_gb:.2f} GB/s") + else: + print(" FAILED: No successful runs") + return 1 + + # Summary + print("\n" + "=" * 70) + print("BENCHMARK PARAMETERS REFERENCE") + print("=" * 70) + print(""" +Available parameters for GEMM benchmarking: + + --warmup N Number of warmup iterations (discard results) + Higher = more stable results, longer run time + Default: 5 + + --repeat N Number of benchmark iterations + Higher = more accurate average, longer run time + Default: 20 + + --flush-cache Flush GPU L2 cache between iterations + Use for memory-bound benchmarks + Default: off + + --timer {gpu,cpu} Timer type + gpu = HIP events (more accurate for GPU) + cpu = std::chrono (includes kernel launch overhead) + Default: gpu + + --init METHOD Matrix initialization + random = uniform random [-0.5, 0.5] + linear = sequential values + constant = all ones + Default: random + +Note: For C++ examples, these parameters are passed to stream_config: + + ck_tile::stream_config cfg{ + nullptr, // stream_id + true, // time_kernel + 1, // log_level + 5, // cold_niters (warmup) + 20, // nrepeat + true, // is_gpu_timer + false, // flush_cache + 1 // rotating_count + }; +""") + + # Cleanup + cleanup_gemm() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py new file mode 100644 index 0000000000..06743af406 --- /dev/null +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 11: JSON-based Kernel Configuration Import + +Demonstrates loading kernel configurations from JSON files, similar to tile_engine. +This enables easy customization of kernel sets without modifying code. + +Key Features: + - Load tile configs from JSON (compatible with tile_engine format) + - Generate kernel sets from configuration + - Use arch_filter validation on loaded configs + - Export to C++ DECL_KERNEL_SET format + +Complexity: ★★★☆☆ + +Usage: + python3 11_json_import.py + python3 11_json_import.py --config my_kernels.json + python3 11_json_import.py --export-cpp +""" + +import sys +import argparse +import json +from pathlib import Path + +# Add codegen to path for kernel_config_loader +script_dir = Path(__file__).parent.resolve() +sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen")) +sys.path.insert(0, str(script_dir.parent.parent.parent / "python")) + +from kernel_config_loader import ( # noqa: E402 + load_kernel_configs, + KernelConfig, + generate_cpp_kernel_set_declaration, +) + +from ctypes_utils import ( # noqa: E402 + KernelConfig as DispatcherKernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + validate_kernel_config, +) + +# Sample JSON configuration (embedded for demonstration) +SAMPLE_JSON_CONFIG = { + "_comment": "Sample kernel configuration for GEMM", + "kernel_set_name": "inference_kernels", + "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"}, + "layout": "rcr", + "tile_config": { + "tile_m": {"values": [128, 256]}, + "tile_n": {"values": [128, 256]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]}, + }, + "trait_config": { + "pipeline": {"values": ["compv4"]}, + "scheduler": {"values": ["intrawave"]}, + "epilogue": {"values": ["cshuffle"]}, + "pad_m": {"values": [False]}, + "pad_n": {"values": [False]}, + "pad_k": {"values": [False]}, + }, + "gpu_targets": ["gfx942"], +} + + +def print_section(title: str): + """Print a section header""" + print(f"\n{'=' * 70}") + print(f" {title}") + print(f"{'=' * 70}\n") + + +def convert_to_dispatcher_config( + config: KernelConfig, arch: str = "gfx942" +) -> DispatcherKernelConfig: + """Convert kernel_config_loader.KernelConfig to dispatcher KernelConfig""" + return DispatcherKernelConfig( + dtype_a=config.dtype_a, + dtype_b=config.dtype_b, + dtype_c=config.dtype_c, + dtype_acc=config.dtype_acc, + tile_m=config.tile.tile_m, + tile_n=config.tile.tile_n, + tile_k=config.tile.tile_k, + wave_m=config.tile.warp_m, + wave_n=config.tile.warp_n, + wave_k=config.tile.warp_k, + warp_m=config.tile.warp_tile_m, + warp_n=config.tile.warp_tile_n, + warp_k=config.tile.warp_tile_k, + pipeline=config.trait.pipeline, + scheduler=config.trait.scheduler, + epilogue=config.trait.epilogue, + pad_m=config.trait.pad_m, + pad_n=config.trait.pad_n, + pad_k=config.trait.pad_k, + gfx_arch=arch, + variant=config.variant, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Kernel Configuration Import Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 11_json_import.py # Use embedded sample config + python3 11_json_import.py --config my.json # Load from file + python3 11_json_import.py --export-cpp # Generate C++ declarations + python3 11_json_import.py --validate # Validate configs against arch + """, + ) + parser.add_argument( + "--config", + type=str, + help="Path to JSON configuration file (uses embedded sample if not provided)", + ) + parser.add_argument( + "--export-cpp", + action="store_true", + help="Export kernel set as C++ DECL_KERNEL_SET", + ) + parser.add_argument( + "--validate", + action="store_true", + help="Validate all configurations against arch filter", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target GPU architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print_section("Example 11: JSON Kernel Configuration Import") + + # ========================================================================= + # Step 1: Load configuration from JSON + # ========================================================================= + print("Step 1: Load Kernel Configuration from JSON") + print("-" * 50) + + if args.config: + config_path = Path(args.config) + if not config_path.exists(): + print(f" ERROR: Config file not found: {config_path}") + return 1 + print(f" Loading from: {config_path}") + config_set = load_kernel_configs(config_path) + else: + # Use embedded sample config + print(" Using embedded sample configuration") + # Write to temp file and load + temp_path = Path("/tmp/sample_gemm_config.json") + with open(temp_path, "w") as f: + json.dump(SAMPLE_JSON_CONFIG, f, indent=2) + config_set = load_kernel_configs(temp_path) + + print(f"\n Kernel Set Name: {config_set.name}") + print( + f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}" + ) + print(f" Layout: {config_set.layout}") + print(f" GPU Targets: {config_set.gpu_targets}") + print(f" Total Configurations: {config_set.config_count()}") + + # ========================================================================= + # Step 2: Display configuration details + # ========================================================================= + print("\nStep 2: Configuration Details") + print("-" * 50) + + print("\n Tile Configurations:") + print(f" tile_m: {config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print( + f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}" + ) + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + + print("\n Trait Configurations:") + print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + + # ========================================================================= + # Step 3: Generate and display kernel names + # ========================================================================= + print("\nStep 3: Generated Kernel Names") + print("-" * 50) + + configs = list(config_set.generate_configs()) + for i, config in enumerate(configs[:5]): + print(f" {i + 1}. {config.kernel_name()}") + if len(configs) > 5: + print(f" ... and {len(configs) - 5} more configurations") + + # ========================================================================= + # Step 4: Validate against arch filter (optional) + # ========================================================================= + if args.validate: + print("\nStep 4: Architecture Validation") + print("-" * 50) + + valid_count = 0 + invalid_count = 0 + + for config in configs: + disp_config = convert_to_dispatcher_config(config, args.arch) + result = validate_kernel_config(disp_config) + + if result.is_valid: + valid_count += 1 + else: + invalid_count += 1 + if invalid_count <= 3: # Show first 3 invalid + print(f"\n ✗ Invalid: {config.kernel_name()}") + for error in result.errors: + print(f" Error: {error}") + + print("\n Validation Summary:") + print(f" ✓ Valid: {valid_count}") + print(f" ✗ Invalid: {invalid_count}") + print(f" Total: {len(configs)}") + + # ========================================================================= + # Step 5: Export to C++ (optional) + # ========================================================================= + if args.export_cpp: + print("\nStep 5: C++ Export") + print("-" * 50) + print("\n // Generated DECL_KERNEL_SET from JSON config:") + print(" // " + "=" * 56) + cpp_code = generate_cpp_kernel_set_declaration(config_set) + for line in cpp_code.split("\n"): + print(f" {line}") + + # ========================================================================= + # Step 6: Use first config with dispatcher (demo) + # ========================================================================= + print("\nStep 6: Dispatcher Integration Demo") + print("-" * 50) + + if configs: + first_config = configs[0] + disp_config = convert_to_dispatcher_config(first_config, args.arch) + + print( + f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}" + ) + + setup = setup_gemm_dispatcher( + disp_config, registry_name="json_import", verbose=False + ) + if setup.success: + print(" ✓ Dispatcher setup successful") + print( + f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" + ) + else: + print(f" ⚠ Dispatcher setup: {setup.error}") + print(" (This is expected if kernels aren't generated)") + + # ========================================================================= + # Summary + # ========================================================================= + print_section("Summary") + print(" JSON configuration allows easy kernel set customization:") + print(" - Define tile sizes and ranges") + print(" - Specify trait combinations (pipeline, scheduler, etc.)") + print(" - Target multiple GPU architectures") + print(" - Export to C++ DECL_KERNEL_SET for static compilation") + print() + print(" JSON Format (tile_engine compatible):") + print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},') + print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}') + print() + print(" Usage:") + print(" config_set = load_kernel_configs('my_kernels.json')") + print(" for config in config_set.generate_configs():") + print(" # Use config for codegen or dispatcher setup") + + cleanup_gemm() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md new file mode 100644 index 0000000000..0a83f3533f --- /dev/null +++ b/dispatcher/examples/gemm/python/README.md @@ -0,0 +1,299 @@ +# GEMM Python Examples + +CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build Library + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build Python library (kernels generated automatically) +make dispatcher_gemm_lib -j$(nproc) +``` + +### Run Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +## Examples + +| Example | Description | +|---------|-------------| +| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support | +| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations | +| [03_benchmark.py](03_benchmark.py) | Performance benchmarking | +| [04_validation.py](04_validation.py) | CPU reference validation | +| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration | +| [06_json_export.py](06_json_export.py) | Registry JSON export | +| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing | +| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection | +| [09_multi_registry.py](09_multi_registry.py) | Multiple registries | +| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control | +| [11_json_import.py](11_json_import.py) | Import kernels from JSON | + +## Example Details + +### 01_basic_gemm.py - Basic GEMM +Demonstrates the Python API with multi-kernel support: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define multiple kernel configurations +kernels = [ + KernelConfig( + tile_m=128, tile_n=128, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv3", scheduler="intrawave" + ), + KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" + ), +] + +# Display configurations +print_kernel_config_table(kernels) + +# Set up dispatcher with all kernels +lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + +# Run GEMM +elapsed_ms = run_gemm(lib, M, N, K, ...) +``` + +### 02_batch_gemm.py - Batch GEMM +Batched matrix multiplication: +- Multiple independent GEMM operations +- Batch dimension handling + +### 03_benchmark.py - Benchmarking +Performance measurement: +- GPU timing +- TFLOPS calculation +- Multiple iterations + +### 04_validation.py - Validation +Correctness verification: +- NumPy reference implementation +- Tolerance-based validation +- Error reporting + +### 05_numpy_integration.py - NumPy Integration +Seamless NumPy integration: +- NumPy arrays to GPU buffers +- Results back to NumPy +- Automatic type conversion + +### 06_json_export.py - JSON Export +Registry serialization for tool integration: +- Export kernel configurations +- Machine-readable format + +### 07_stress_test.py - Stress Testing +Comprehensive multi-kernel stress testing: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define 48 unique kernel configurations +kernels = [ + KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...), + KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...), + KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...), + # ... many more configurations +] + +# Test each kernel +for i, kernel in enumerate(kernels): + lib, dispatcher, registry = setup_gemm_dispatcher([kernel]) + result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel + print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}") +``` + +**Features:** +- 48 unique kernel configurations +- Various tile sizes, pipelines, and schedulers +- Per-kernel validation with unique random seeds +- Performance reporting + +### 08_heuristics.py - Heuristic Selection +Custom kernel selection based on problem characteristics: + +```python +# Define kernel pools for different strategies +SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...] +LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...] +COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...] +MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...] + +# Size-based heuristic +def size_based_heuristic(M, N, K): + if M * N < 512 * 512: + return SMALL_KERNELS + else: + return LARGE_KERNELS + +# Strategy-based selection +def compute_strategy(): + return COMPUTE_KERNELS # Optimized for compute-bound problems + +def memory_strategy(): + return MEMORY_KERNELS # Optimized for memory-bound problems + +# Test different strategies +for strategy in [size_based_heuristic, compute_strategy, memory_strategy]: + kernels = strategy(M, N, K) + lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + elapsed_ms = run_gemm(lib, M, N, K, ...) +``` + +**Features:** +- 24 kernel configurations across 6 categories +- Size-based heuristic (small vs large) +- Optimization strategies (compute, memory, latency) +- Performance comparison across strategies + +### 09_multi_registry.py - Multiple Registries +Separate registries for different workloads: +- Compute-optimized registry +- Latency-optimized registry +- Dynamic registry selection + +### 10_advanced_benchmark.py - Advanced Benchmark +Full control over benchmark parameters: +- Warmup iterations +- Benchmark iterations +- Statistical analysis + +### 11_json_import.py - JSON Import +Import kernel configurations from JSON: +- External configuration files +- Dynamic kernel loading + +## Utility Module: ctypes_utils.py + +```python +from ctypes_utils import ( + KernelConfig, # Single kernel configuration + setup_gemm_dispatcher, # Set up dispatcher with kernels + print_kernel_config_table, # Display kernel configurations + Dispatcher, # High-level dispatcher + Registry, # Kernel registry + Validator, # Validation utilities +) +``` + +### KernelConfig + +```python +config = KernelConfig( + # Tile sizes + tile_m=256, tile_n=256, tile_k=32, + # Wave configuration + wave_m=2, wave_n=2, wave_k=1, + # Warp tile sizes + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + # Pipeline and scheduler + pipeline="compv4", # "compv3" or "compv4" + scheduler="intrawave", # "intrawave" or "interwave" + # Optional + epilogue="default", + padding=True, + double_buffer=True, +) +``` + +### setup_gemm_dispatcher + +```python +# Single kernel +lib, dispatcher, registry = setup_gemm_dispatcher(config) + +# Multiple kernels +lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...]) + +# With auto-rebuild +lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True) +``` + +### print_kernel_config_table + +```python +kernels = [config1, config2, config3] +print_kernel_config_table(kernels) +# Output: +# +----+-------+-------+-------+--------+-----------+ +# | # | Tile | Wave | Warp | Pipe | Scheduler | +# +----+-------+-------+-------+--------+-----------+ +# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave | +# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave | +# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave | +# +----+-------+-------+-------+--------+-----------+ +``` + +### GPU Memory Management + +```python +import ctypes +import numpy as np + +# Load HIP library +hip = ctypes.CDLL("libamdhip64.so") + +# Allocate GPU memory +gpu_ptr = ctypes.c_void_p() +hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes) + +# Copy to GPU (1 = hipMemcpyHostToDevice) +hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1) + +# Copy back (2 = hipMemcpyDeviceToHost) +hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2) + +# Free +hip.hipFree(gpu_ptr) +``` + +## Performance Testing + +Test compilation performance with different kernel counts: + +```bash +# Test with 10 kernels (~15s compile time) +python3 01_basic_gemm.py --num-kernels 10 + +# Test with 20 kernels (~25s compile time) +python3 01_basic_gemm.py --num-kernels 20 + +# Test with 48 kernels (~50s compile time) +python3 01_basic_gemm.py --num-kernels 48 +``` + +Compilation time scales roughly linearly with kernel count. + +## Related Documentation + +- [C++ GEMM Examples](../cpp/README.md) +- [Python Conv Examples](../../conv/python/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/kernels.json b/dispatcher/examples/gemm/python/kernels.json new file mode 100644 index 0000000000..214b1cc42c --- /dev/null +++ b/dispatcher/examples/gemm/python/kernels.json @@ -0,0 +1,80 @@ +{ + "registry": "export_demo", + "kernel_count": 3, + "kernels": [ + { + "tile": "128x128x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "256x256x64", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "64x64x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + } + ], + "cpp_registry": { + "metadata": { + "timestamp": "Dec 4 2025 06:23:15", + "total_kernels": 1, + "export_version": "1.0", + "dispatcher_version": "1.0.0" + }, + "statistics": { + "by_datatype": {}, + "by_pipeline": {}, + "by_scheduler": {} + }, + "kernels": [ + { + "identifier": "128x128x32_2x2x1_32x32x16_nopers", + "name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16", + "algorithm": { + "tile_shape": { + "m": 128, + "n": 128, + "k": 32 + }, + "wave_shape": { + "m": 2, + "n": 2, + "k": 1 + }, + "warp_tile_shape": { + "m": 32, + "n": 32, + "k": 16 + }, + "block_size": 256, + "persistent": false, + "double_buffer": true, + "preshuffle": false, + "transpose_c": false + } + } + ] + } +} \ No newline at end of file diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp new file mode 100644 index 0000000000..98d8bb9333 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -0,0 +1,19 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +/// Main dispatcher header - includes all core components +/// Use this for convenient access to the full dispatcher API + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md new file mode 100644 index 0000000000..db3ce996a9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -0,0 +1,161 @@ +# CK Tile Dispatcher - C++ Headers + +C++ API for the CK Tile dispatcher. + +> **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. + +## File Organization + +``` +dispatcher/ +├── dispatcher.hpp # Main dispatcher (kernel selection) +├── registry.hpp # Kernel registry (storage & lookup) +├── problem.hpp # Problem specification +├── kernel_key.hpp # Kernel configuration key +├── kernel_instance.hpp # Kernel instance interface +├── utils.hpp # Utilities (timers, GPU buffers) +│ +└── backends/ # Backend implementations + ├── generated_tile_backend.hpp # CK Tile kernels (production) + └── tile_backend.hpp # Tile backend base +``` + +## Quick Start + +```cpp +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +int main() { + // 1. Build kernel key + KernelKeyBuilder builder = KernelKeyBuilder::fp16_rcr(); + builder.tile_m = 128; + builder.tile_n = 128; + builder.tile_k = 32; + KernelKey key = builder.build(); + + // 2. Register kernel + auto kernel = create_generated_tile_kernel<...>(key, "my_kernel"); + Registry::instance().register_kernel(kernel, Priority::High); + + // 3. Run GEMM + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + float time_ms = dispatcher.run(a_ptr, b_ptr, c_ptr, problem, nullptr); +} +``` + +## Core Classes + +### KernelKey (`kernel_key.hpp`) + +Uniquely identifies a kernel configuration: + +```cpp +KernelKeyBuilder builder; +builder.dtype_a = DataType::FP16; +builder.layout_a = LayoutTag::Row; +builder.tile_m = 256; +builder.pipeline = Pipeline::CompV4; +KernelKey key = builder.build(); +``` + +### Registry (`registry.hpp`) + +Thread-safe kernel storage: + +```cpp +auto& registry = Registry::instance(); +registry.register_kernel(kernel, Priority::High); +registry.get_kernel_count(); +registry.export_json(); +``` + +### Dispatcher (`dispatcher.hpp`) + +Kernel selection and execution: + +```cpp +Dispatcher dispatcher; + +// Strategies +dispatcher.set_strategy(SelectionStrategy::FirstFit); +dispatcher.set_strategy(SelectionStrategy::Heuristic); + +// Run +float time = dispatcher.run(a, b, c, problem, stream); +``` + +### Problem (`problem.hpp`) + +GEMM problem specification: + +```cpp +Problem problem(M, N, K); +problem.batch_size = 4; +problem.alpha = 1.0f; +problem.beta = 0.0f; + +// Auto-inference +auto p = Problem::from_ab(a_rows, a_cols, b_rows, b_cols, trans_a, trans_b); +``` + +## Utilities (`utils.hpp`) + +### GPU Memory + +```cpp +GpuBuffer buffer(size); +buffer.copy_from_host(host_ptr); +buffer.copy_to_host(host_ptr); +buffer.zero(); +``` + +### Timing + +```cpp +GpuTimer timer; +timer.start(); +// kernel... +timer.stop(); +float ms = timer.elapsed_ms(); +``` + +### Quick Helpers + +```cpp +// Create FP16 RCR key +auto key = create_fp16_rcr_key(tile_m, tile_n, tile_k, ...); + +// Performance +double tflops = calculate_tflops(M, N, K, time_ms); + +// Validation +auto result = validate_result(gpu_ptr, cpu_ptr, size); +``` + +## Backend + +### Generated Tile Backend + +```cpp +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +auto kernel = create_generated_tile_kernel< + SelectedKernel, ADataType, BDataType, CDataType, AccDataType +>(key, name); +``` + +## Best Practices + +1. Use `Release` build for performance +2. Register kernels at startup +3. Use `Priority::High` for hand-tuned kernels +4. Reuse dispatcher instances +5. Clear registry between test runs + +--- + +> **More info:** See [../../../../README.md](../../../../README.md) for full documentation. diff --git a/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp new file mode 100644 index 0000000000..33a864a649 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp @@ -0,0 +1,393 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Architecture-Specific Kernel Filtering for CK Tile Dispatcher + * + * Provides GPU architecture-aware validation of kernel configurations. + * Uses arch_specs_generated.hpp as single source of truth (generated from arch_specs.json). + * + * Usage: + * ArchFilter filter("gfx942"); + * + * // Check if a kernel configuration is valid + * if (filter.is_valid(kernel_key)) { + * registry.register_kernel(kernel); + * } + * + * // Get validation result with error details + * auto result = filter.validate(kernel_key); + * if (!result.valid) { + * for (const auto& error : result.errors) { + * std::cerr << error << "\n"; + * } + * } + * + * Adding New GPU Support: + * 1. Edit dispatcher/codegen/arch_specs.json + * 2. Run: python dispatcher/codegen/generate_arch_specs.py + * 3. Rebuild the dispatcher + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/arch_specs_generated.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Re-export from generated header for convenience +// ============================================================================= + +// Use the generated types and functions from arch_specs namespace +using GpuArch = arch_specs::GpuArch; +using WarpConfig = arch_specs::WarpConfig; +using WarpTileConfig = std::array; + +// Re-export string conversion functions +using arch_specs::arch_to_string; +using arch_specs::element_size; +using arch_specs::get_lds_capacity; +using arch_specs::get_supported_warp_configs; +using arch_specs::is_trait_unsupported; +using arch_specs::string_to_arch; + +// ============================================================================= +// Additional Helper Functions +// ============================================================================= + +/// Get supported warp tile configurations for arch and data types +/// This function wraps the generated data with runtime logic +inline std::vector get_supported_warp_tiles(GpuArch arch, + DataType dtype_a, + DataType dtype_b, + [[maybe_unused]] DataType dtype_c) +{ + // Common FP16 configurations (from arch_specs.json) + std::vector fp16_configs = { + {32, 32, 8}, {16, 16, 16}, {32, 32, 16}, {16, 16, 32}, {4, 64, 16}, {64, 4, 16}}; + + // FP8 configurations + std::vector fp8_gfx942 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}}; + std::vector fp8_gfx950 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}, {16, 16, 128}, {32, 32, 64}}; + + // INT8 configurations + std::vector int8_configs = {{16, 16, 32}, {32, 32, 16}}; + + // GFX1201 only supports limited FP16 + std::vector rdna4_fp16 = {{16, 16, 16}}; + + // Match based on architecture and data types + if(dtype_a == DataType::FP16 && dtype_b == DataType::FP16) + { + if(arch == GpuArch::GFX_1201) + return rdna4_fp16; + return fp16_configs; + } + if(dtype_a == DataType::BF16 && dtype_b == DataType::BF16) + { + if(arch == GpuArch::GFX_1201) + return {}; // Not supported on RDNA4 + return fp16_configs; // Same as FP16 + } + if(dtype_a == DataType::FP8 || dtype_a == DataType::BF8) + { + if(arch == GpuArch::GFX_950) + return fp8_gfx950; + if(arch == GpuArch::GFX_942) + return fp8_gfx942; + if(arch == GpuArch::GFX_90A) + return {{32, 32, 16}, {32, 32, 32}}; + } + if(dtype_a == DataType::INT8 && dtype_b == DataType::INT8) + { + if(arch == GpuArch::GFX_942) + return int8_configs; + } + + return {}; // Unknown combination +} + +// ============================================================================= +// Validation Result +// ============================================================================= + +/// Result of kernel validation +struct ValidationResult +{ + bool valid = true; + std::vector errors; + std::vector warnings; + + explicit operator bool() const { return valid; } + + void add_error(const std::string& msg) + { + errors.push_back(msg); + valid = false; + } + + void add_warning(const std::string& msg) { warnings.push_back(msg); } +}; + +// ============================================================================= +// Architecture Filter +// ============================================================================= + +/** + * Architecture-specific kernel filter. + * + * Validates kernel configurations against GPU architecture constraints + * including warp configurations, warp tiles, LDS capacity, and traits. + */ +class ArchFilter +{ + public: + /** + * Create architecture filter. + * @param arch Target GPU architecture + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(GpuArch arch, bool strict_mode = false) + : arch_(arch), strict_mode_(strict_mode) + { + } + + /** + * Create architecture filter from string. + * @param arch_str GPU architecture string (e.g., "gfx942") + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(const std::string& arch_str, bool strict_mode = false) + : arch_(string_to_arch(arch_str)), strict_mode_(strict_mode) + { + } + + /** + * Quick validation check. + * @param key Kernel configuration key + * @return true if configuration is valid for this architecture + */ + [[nodiscard]] bool is_valid(const KernelKey& key) const { return validate(key).valid; } + + /** + * Detailed validation with error messages. + * @param key Kernel configuration key + * @return ValidationResult with valid flag and error/warning messages + */ + [[nodiscard]] ValidationResult validate(const KernelKey& key) const + { + ValidationResult result; + + // Check architecture match + if(!key.gfx_arch.empty() && string_to_arch(key.gfx_arch) != arch_) + { + result.add_warning("Kernel compiled for different architecture: " + key.gfx_arch); + } + + // Validate dimensions + validate_dimensions(key, result); + + // Validate warp configuration + validate_warp_config(key, result); + + // Validate warp tile configuration + validate_warp_tiles(key, result); + + // Validate trait combination + validate_traits(key, result); + + // Validate LDS capacity + validate_lds(key, result); + + return result; + } + + /// Get target architecture + [[nodiscard]] GpuArch get_arch() const { return arch_; } + + /// Get target architecture as string + [[nodiscard]] std::string get_arch_string() const { return arch_to_string(arch_); } + + private: + void validate_dimensions(const KernelKey& key, ValidationResult& result) const + { + const auto& alg = key.algorithm; + + // Check positive dimensions + if(alg.tile_shape.m <= 0 || alg.tile_shape.n <= 0 || alg.tile_shape.k <= 0) + { + result.add_error("Tile dimensions must be positive"); + return; + } + + // Check warp tiles fit in block tiles + int warp_m_coverage = alg.wave_shape.m * alg.warp_tile_shape.m; + int warp_n_coverage = alg.wave_shape.n * alg.warp_tile_shape.n; + int warp_k_coverage = alg.wave_shape.k * alg.warp_tile_shape.k; + + if(warp_m_coverage > alg.tile_shape.m) + { + result.add_error("warp_m * warp_tile_m > tile_m: " + std::to_string(warp_m_coverage) + + " > " + std::to_string(alg.tile_shape.m)); + } + if(warp_n_coverage > alg.tile_shape.n) + { + result.add_error("warp_n * warp_tile_n > tile_n: " + std::to_string(warp_n_coverage) + + " > " + std::to_string(alg.tile_shape.n)); + } + if(warp_k_coverage > alg.tile_shape.k) + { + result.add_error("warp_k * warp_tile_k > tile_k: " + std::to_string(warp_k_coverage) + + " > " + std::to_string(alg.tile_shape.k)); + } + + // Check alignment + if(alg.tile_shape.m % warp_m_coverage != 0) + { + result.add_error("tile_m must be divisible by warp_m * warp_tile_m"); + } + if(alg.tile_shape.n % warp_n_coverage != 0) + { + result.add_error("tile_n must be divisible by warp_n * warp_tile_n"); + } + if(alg.tile_shape.k % warp_k_coverage != 0) + { + result.add_error("tile_k must be divisible by warp_k * warp_tile_k"); + } + } + + void validate_warp_config(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_configs(arch_); + if(supported.empty()) + { + if(strict_mode_) + { + result.add_error("No warp configurations defined for " + get_arch_string()); + } + else + { + result.add_warning("No warp configurations defined for " + get_arch_string()); + } + return; + } + + WarpConfig current = { + key.algorithm.wave_shape.m, key.algorithm.wave_shape.n, key.algorithm.wave_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { + result.add_error("Invalid warp configuration [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_warp_tiles(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_tiles( + arch_, key.signature.dtype_a, key.signature.dtype_b, key.signature.dtype_c); + + if(supported.empty()) + { + // Unknown data type combination - allow with warning + result.add_warning("No warp tile combinations defined for data types"); + return; + } + + WarpTileConfig current = {key.algorithm.warp_tile_shape.m, + key.algorithm.warp_tile_shape.n, + key.algorithm.warp_tile_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { + result.add_error("Invalid warp tile [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_traits(const KernelKey& key, ValidationResult& result) const + { + if(is_trait_unsupported( + key.algorithm.pipeline, key.algorithm.epilogue, key.algorithm.scheduler)) + { + result.add_error("Unsupported trait combination"); + } + } + + void validate_lds(const KernelKey& key, ValidationResult& result) const + { + const auto& sig = key.signature; + const auto& alg = key.algorithm; + + float elem_a = element_size(sig.dtype_a); + float elem_b = element_size(sig.dtype_b); + + std::size_t matrix_a_size = alg.tile_shape.m * alg.tile_shape.k * elem_a; + std::size_t matrix_b_size = alg.tile_shape.n * alg.tile_shape.k * elem_b; + std::size_t total_lds = matrix_a_size + matrix_b_size; + + std::size_t max_lds = get_lds_capacity(alg.pipeline); + + if(total_lds > max_lds) + { + result.add_error("LDS capacity exceeded: " + std::to_string(total_lds) + " bytes > " + + std::to_string(max_lds) + " bytes limit"); + } + } + + GpuArch arch_; + bool strict_mode_; +}; + +// ============================================================================= +// Registry Integration Helper +// ============================================================================= + +/** + * Create a filter function for use with Registry::filter() + * + * @tparam KernelT Kernel instance type with get_key() method + * @param arch Target GPU architecture + * @return Predicate function that returns true for valid kernels + */ +template +inline auto make_arch_filter_predicate(const std::string& arch) +{ + return [filter = ArchFilter(arch)](const KernelT& kernel) { + return filter.is_valid(kernel.get_key()); + }; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp new file mode 100644 index 0000000000..af52c8eb1d --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: 2026-01-05T19:34:01.229811 + * + * To update this file: + * 1. Edit arch_specs.json + * 2. Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace arch_specs { + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t +{ + GFX_908, // AMD Instinct MI100 + GFX_90A, // AMD Instinct MI200 series + GFX_942, // AMD Instinct MI300 series + GFX_950, // AMD Instinct MI350 series + GFX_1100, // AMD Radeon RX 7900 series (RDNA3) + GFX_1200, // AMD Radeon RX 9000 series (RDNA4) + GFX_1201, // AMD Radeon RX 9000 series (RDNA4) + UNKNOWN +}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return "gfx908"; + case GpuArch::GFX_90A: return "gfx90a"; + case GpuArch::GFX_942: return "gfx942"; + case GpuArch::GFX_950: return "gfx950"; + case GpuArch::GFX_1100: return "gfx1100"; + case GpuArch::GFX_1200: return "gfx1200"; + case GpuArch::GFX_1201: return "gfx1201"; + default: return "unknown"; + } +} + +inline GpuArch string_to_arch(const std::string& arch_str) +{ + if(arch_str == "gfx908") + return GpuArch::GFX_908; + if(arch_str == "gfx90a") + return GpuArch::GFX_90A; + if(arch_str == "gfx942") + return GpuArch::GFX_942; + if(arch_str == "gfx950") + return GpuArch::GFX_950; + if(arch_str == "gfx1100") + return GpuArch::GFX_1100; + if(arch_str == "gfx1200") + return GpuArch::GFX_1200; + if(arch_str == "gfx1201") + return GpuArch::GFX_1201; + return GpuArch::UNKNOWN; +} + +// ============================================================================= +// Element Size (Generated) +// ============================================================================= + +inline float element_size(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return 2.0f; + case DataType::BF16: return 2.0f; + case DataType::FP32: return 4.0f; + case DataType::FP64: return 8.0f; + case DataType::FP8: return 1.0f; + case DataType::BF8: return 1.0f; + case DataType::INT8: return 1.0f; + case DataType::INT4: return 0.5f; + case DataType::INT32: return 4.0f; + default: return 2.0f; + } +} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_90A: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_942: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_950: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_1100: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1200: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1201: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + default: return {}; + } +} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) +{ + if(pipeline == Pipeline::Mem) + return 65536; + if(pipeline == Pipeline::CompV1) + return 65536; + if(pipeline == Pipeline::CompV2) + return 65536; + if(pipeline == Pipeline::CompV3) + return 65536; + if(pipeline == Pipeline::CompV4) + return 32768; + if(pipeline == Pipeline::CompV5) + return 65536; + if(pipeline == Pipeline::PreShuffleV1) + return 32768; + if(pipeline == Pipeline::PreShuffleV2) + return 32768; + return 65536; // Default +} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool +is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) +{ + // Generated from unsupported_trait_combos in arch_specs.json + if(scheduler == Scheduler::Interwave) + { + if(pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) + { + return true; + } + } + return false; +} + +} // namespace arch_specs +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp new file mode 100644 index 0000000000..79f8f30a9b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp @@ -0,0 +1,143 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Generated Kernel Backend + * + * Backend for kernels generated by unified_gemm_codegen.py + * with unique namespace wrapping (Kernel_{name}). + * + * Status: Work in progress - use generated_tile_backend.hpp for now + * + * This backend handles the new codegen format with unique kernel structs. + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have: + * - namespace {kernel_name}_ns { ... } (NEW format) + * - struct Kernel_{name} with static launch() method + * - struct SelectedKernel alias for compatibility + * - Type aliases: ADataType, BDataType, CDataType, AccDataType + * + * Note: Currently use generated_tile_backend.hpp for production + */ +template +class GeneratedKernelInstance : public KernelInstance +{ + public: + using SelectedKernel = SelectedKernelType; + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + using AccDataType = typename SelectedKernel::AccDataType; + + GeneratedKernelInstance(const KernelKey& key, const std::string& name) : key_(key), name_(name) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility based on padding flags + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + return true; // Padding enabled - supports any size + } + + // Check divisibility for dimensions without padding + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + (void)d_ptrs; // Not used in basic GEMM + + // Create arguments using constructor + ck_tile::GemmHostArgs args(a_ptr, // a_ptr + b_ptr, // b_ptr + c_ptr, // e_ptr/c_ptr + problem.k_batch, // k_batch + problem.M, // M + problem.N, // N + problem.K, // K + problem.K, // stride_A (row-major A: stride = K) + problem.K, // stride_B (column-major B: stride = K) + problem.N // stride_E/C (row-major C: stride = N) + ); + + // Create stream config for timing + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = reinterpret_cast(stream); + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = 0; + stream_cfg.cold_niters_ = 5; // Warmup iterations + stream_cfg.nrepeat_ = 10; // Measurement iterations + stream_cfg.is_gpu_timer_ = true; + stream_cfg.flush_cache_ = false; + stream_cfg.rotating_count_ = 1; + + // Call the generated kernel's launch method + return SelectedKernel::launch(args, stream_cfg); + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + (void)a_ptr; + (void)b_ptr; + (void)c_ptr; + (void)d_ptrs; + (void)problem; + (void)tolerance; + // Validation would require reference implementation + return true; + } + + private: + KernelKey key_; + std::string name_; +}; + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp new file mode 100644 index 0000000000..76565045cf --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp @@ -0,0 +1,157 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have structure: + * - Types defined outside: using ADataType = ...; using BDataType = ...; + * - struct SelectedKernel with static constexpr config and launch() method + * - constexpr const char* KERNEL_NAME = "..."; + * + * This is different from tile_engine style where everything is in SelectedKernel. + */ +template +class GeneratedTileKernelInstance : public KernelInstance +{ + public: + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using SelectedKernel = SelectedKernelType; + + GeneratedTileKernelInstance(const KernelKey& key, const std::string& name) + : key_(key), name_(name) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + return true; // Padding enabled - supports any size + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + (void)d_ptrs; // Not used in basic GEMM + + // Create arguments using constructor (correct order!) + // Order from GemmHostArgs constructor: a_ptr, b_ptr, e_ptr, k_batch, M, N, K, stride_A, + // stride_B, stride_E + ck_tile::GemmHostArgs args(a_ptr, // a_ptr + b_ptr, // b_ptr + c_ptr, // e_ptr/c_ptr + problem.k_batch, // k_batch (4th argument!) + problem.M, // M + problem.N, // N + problem.K, // K + problem.K, // stride_A (row-major A: stride = K) + problem.K, // stride_B (column-major B: stride = K) + problem.N // stride_E/C (row-major C: stride = N) + ); + + // Create stream config for timing + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = reinterpret_cast(stream); + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = 0; // No logging for performance + stream_cfg.cold_niters_ = 5; // Warmup iterations + stream_cfg.nrepeat_ = 10; // Measurement iterations + stream_cfg.is_gpu_timer_ = true; + stream_cfg.flush_cache_ = false; + stream_cfg.rotating_count_ = 1; + + // Call the generated kernel's launch method + return SelectedKernel::launch(args, stream_cfg); + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + (void)a_ptr; + (void)b_ptr; + (void)c_ptr; + (void)d_ptrs; + (void)problem; + (void)tolerance; + // Validation would require reference implementation + return true; + } + + private: + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a generated tile kernel instance wrapper +template +std::shared_ptr create_generated_tile_kernel(const KernelKey& key, + const std::string& name) +{ + return std::make_shared< + GeneratedTileKernelInstance>( + key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp new file mode 100644 index 0000000000..01ab1f5e52 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Helper to register a CK Tile generated kernel +/// This should be called from generated code for each kernel +template +void register_tile_kernel(Registry& registry, const std::string& kernel_name) +{ + // Extract metadata from SelectedKernel static members + KernelKey key; + + // Signature + key.signature.dtype_a = static_cast(SelectedKernel::ADataType); + key.signature.dtype_b = static_cast(SelectedKernel::BDataType); + key.signature.dtype_c = static_cast(SelectedKernel::CDataType); + key.signature.dtype_acc = static_cast(SelectedKernel::AccDataType); + + key.signature.layout_a = static_cast(SelectedKernel::ALayout); + key.signature.layout_b = static_cast(SelectedKernel::BLayout); + key.signature.layout_c = static_cast(SelectedKernel::CLayout); + + key.signature.transpose_a = false; // Extract from kernel if available + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + + key.signature.elementwise_op = "PassThrough"; // Extract if available + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = SelectedKernel::UseStructuredSparsity; + + // Algorithm + key.algorithm.tile_shape.m = SelectedKernel::TileM; + key.algorithm.tile_shape.n = SelectedKernel::TileN; + key.algorithm.tile_shape.k = SelectedKernel::TileK; + + key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M; + key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N; + key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K; + + key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM; + key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN; + key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK; + + // Extract pipeline, epilogue, scheduler from traits + key.algorithm.pipeline = Pipeline::CompV4; // Extract from kernel + key.algorithm.epilogue = Epilogue::Default; // Extract from kernel + key.algorithm.scheduler = Scheduler::Auto; // Extract from kernel + + key.algorithm.block_size = SelectedKernel::BlockSize; + key.algorithm.double_buffer = SelectedKernel::DoubleSmemBuffer; + key.algorithm.persistent = SelectedKernel::UsePersistentKernel; + key.algorithm.preshuffle = false; // Extract if available + key.algorithm.transpose_c = SelectedKernel::TransposeC; + key.algorithm.num_wave_groups = 1; // Extract if available + + key.gfx_arch = 942; // Extract from build configuration + + // Create kernel instance + auto kernel_instance = std::make_shared>(key, kernel_name); + + // Register with high priority (Tile kernels preferred) + registry.register_kernel(kernel_instance, Registry::Priority::High); +} + +/// Macro to simplify kernel registration in generated code +#define CK_TILE_REGISTER_KERNEL(SelectedKernel, KernelName, Registry) \ + ::ck_tile::dispatcher::backends::register_tile_kernel(Registry, KernelName) + +/// Helper to register multiple kernels from a list +template +struct KernelRegistrar +{ + static void register_all(Registry& registry) + { + // This would be specialized for each kernel set + // For now, empty implementation + } +}; + +/// Auto-registration helper +/// Place this in generated files to automatically register kernels +template +struct AutoRegister +{ + AutoRegister(const std::string& kernel_name) + { + auto& registry = Registry::instance(); + register_tile_kernel(registry, kernel_name); + } +}; + +/// Macro for auto-registration +#define CK_TILE_AUTO_REGISTER(SelectedKernel, KernelName) \ + static ::ck_tile::dispatcher::backends::AutoRegister \ + auto_register_##SelectedKernel{KernelName}; + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp new file mode 100644 index 0000000000..a3a0b04685 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Kernel instance for CK Tile generated kernels +template +class TileKernelInstance : public KernelInstance +{ + public: + TileKernelInstance(const KernelKey& key, const std::string& name) : key_(key), name_(name) {} + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + // Padding enabled - supports any size + return true; + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + // Check shared memory budget if specified + if(problem.smem_budget > 0) + { + int64_t estimated_smem = estimate_smem_usage(); + if(estimated_smem > problem.smem_budget) + return false; + } + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + // Convert void* stream to hipStream_t + hipStream_t hip_stream = reinterpret_cast(stream); + + // Construct kernel arguments + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + + // Note: d_ptrs not yet supported in basic CK Tile kernels + (void)d_ptrs; // Suppress unused parameter warning + + auto kargs = SelectedKernel::MakeKernelArgs(static_cast(a_ptr), + static_cast(b_ptr), + static_cast(c_ptr), + problem.M, + problem.N, + problem.K, + problem.k_batch); + + // Validate arguments + if(!SelectedKernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel does not support the given arguments"); + } + + // Calculate grid and block dimensions + dim3 grids = SelectedKernel::GridSize(problem.M, problem.N, problem.K); + dim3 blocks = SelectedKernel::BlockSize(); + size_t lds_bytes = SelectedKernel::GetSmemSize(); + + // Time kernel execution + hipEvent_t start, stop; + (void)hipEventCreate(&start); + (void)hipEventCreate(&stop); + + (void)hipEventRecord(start, hip_stream); + + // Launch kernel + ck_tile::launch_kernel(SelectedKernel::Kernel, grids, blocks, lds_bytes, hip_stream, kargs); + + (void)hipEventRecord(stop, hip_stream); + (void)hipEventSynchronize(stop); + + float elapsed_ms = 0.0f; + (void)hipEventElapsedTime(&elapsed_ms, start, stop); + + (void)hipEventDestroy(start); + (void)hipEventDestroy(stop); + + return elapsed_ms; + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + // Use validation helper + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + using AccDataType = typename SelectedKernel::AccDataType; + + // d_ptrs not yet supported + (void)d_ptrs; + + // Convert tolerance to rtol and atol + float rtol = tolerance; + float atol = tolerance * 1e-2f; // atol is typically smaller + + return validation::validate_gemm_kernel( + a_ptr, b_ptr, c_ptr, problem, rtol, atol); + } + + private: + int64_t estimate_smem_usage() const + { + // Use kernel's reported shared memory size + return SelectedKernel::GetSmemSize(); + } + + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a tile kernel instance wrapper +/// This should be called from generated code that knows the SelectedKernel type +template +std::shared_ptr create_tile_kernel_instance(const KernelKey& key, + const std::string& name) +{ + return std::make_shared>(key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp new file mode 100644 index 0000000000..6d3f548138 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -0,0 +1,146 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Dispatcher - Main Kernel Selection and Execution Engine + * + * The Dispatcher provides unified interface for selecting and executing + * CK Tile GEMM kernels based on problem specifications. + * + * Features: + * - Multiple selection strategies (FirstFit, Heuristic) + * - Custom heuristic functions + * - Thread-safe registry integration + * - Real GPU execution with timing + * + * Usage: + * Dispatcher dispatcher; + * Problem problem(M, N, K); + * float time = dispatcher.run(a_dev, b_dev, c_dev, problem); + * + * Status: Production ready - 319 TFLOPS validated + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Heuristic function type: maps Problem to ordered list of kernel identifiers +/// Returns kernel identifiers ranked by expected performance (best first) +using HeuristicFunction = std::function(const Problem&)>; + +/// Dispatcher: Top-level orchestration for kernel selection and execution +/// Provides unified interface for kernel dispatch across different backends +class Dispatcher +{ + public: + /// Selection strategy for kernel choice + enum class SelectionStrategy + { + FirstFit, // Use first kernel that supports the problem + Heuristic // Use heuristic function to guide selection + }; + + /// Constructor + /// @param registry Registry instance to use (default: global singleton) + explicit Dispatcher(Registry* registry = nullptr); + + /// Register a heuristic function for kernel selection + /// @param heuristic Function that maps problems to ranked kernel identifiers + void set_heuristic(HeuristicFunction heuristic); + + /// Set selection strategy + /// @param strategy Strategy to use for kernel selection + void set_strategy(SelectionStrategy strategy); + + /// Select a kernel for the given problem + /// @param problem Problem configuration + /// @return Selected kernel instance, or nullptr if no suitable kernel found + [[nodiscard]] KernelInstancePtr select_kernel(const Problem& problem) const; + + /// Execute GEMM operation with automatic kernel selection + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute GEMM operation with fusion (multi-D) + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run_fused(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute with explicit kernel selection + /// @param kernel_id Kernel identifier string + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if kernel not found or doesn't support problem + [[nodiscard]] float run_explicit(const std::string& kernel_id, + const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Validate kernel output + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param tolerance Relative error tolerance + /// @return true if validation passes, false otherwise + [[nodiscard]] bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance = 1e-3f) const; + + private: + Registry* registry_; + HeuristicFunction heuristic_; + SelectionStrategy strategy_; + + /// Select kernel using first-fit strategy + [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; + + /// Select kernel using heuristic strategy + [[nodiscard]] KernelInstancePtr select_heuristic(const Problem& problem) const; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/example_args.hpp b/dispatcher/include/ck_tile/dispatcher/example_args.hpp new file mode 100644 index 0000000000..f93a4d61f6 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/example_args.hpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +/** + * Simple command-line argument parser for examples. + * + * Usage: + * ExampleArgs args("Example 01: Basic GEMM", "Demonstrates basic GEMM usage"); + * args.add_flag("--list", "List all kernel sets"); + * args.add_option("--dtype", "fp16", "Data type (fp16, bf16, fp32)"); + * args.add_option("--size", "1024", "Problem size MxNxK"); + * + * if (!args.parse(argc, argv)) return 0; // --help was printed + * + * bool do_list = args.has("--list"); + * std::string dtype = args.get("--dtype"); + * int size = args.get_int("--size"); + */ +class ExampleArgs +{ + public: + ExampleArgs(const std::string& name, const std::string& description = "") + : name_(name), description_(description) + { + // Always add --help + add_flag("--help", "Show this help message"); + add_flag("-h", "Show this help message"); + } + + // Add a boolean flag (no value) + void add_flag(const std::string& name, const std::string& help) + { + flags_[name] = false; + help_[name] = help; + order_.push_back(name); + } + + // Add an option with a default value + void + add_option(const std::string& name, const std::string& default_val, const std::string& help) + { + options_[name] = default_val; + defaults_[name] = default_val; + help_[name] = help; + order_.push_back(name); + } + + // Parse arguments. Returns false if --help was requested. + bool parse(int argc, char* argv[]) + { + for(int i = 1; i < argc; ++i) + { + std::string arg = argv[i]; + + // Check for --help + if(arg == "--help" || arg == "-h") + { + print_help(); + return false; + } + + // Check for flags + if(flags_.find(arg) != flags_.end()) + { + flags_[arg] = true; + continue; + } + + // Check for options (--name=value or --name value) + std::string name, value; + size_t eq_pos = arg.find('='); + if(eq_pos != std::string::npos) + { + name = arg.substr(0, eq_pos); + value = arg.substr(eq_pos + 1); + } + else if(options_.find(arg) != options_.end() && i + 1 < argc) + { + name = arg; + value = argv[++i]; + } + else + { + // Positional argument - store as _pos_N + std::string pos_name = "_pos_" + std::to_string(positional_.size()); + positional_.push_back(arg); + continue; + } + + if(options_.find(name) != options_.end()) + { + options_[name] = value; + } + } + return true; + } + + // Check if a flag is set + bool has(const std::string& name) const + { + auto it = flags_.find(name); + return it != flags_.end() && it->second; + } + + // Get an option value as string + std::string get(const std::string& name) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : ""; + } + + // Get an option value as string with default + std::string get(const std::string& name, const std::string& default_val) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : default_val; + } + + // Get an option value as int + int get_int(const std::string& name, int default_val = 0) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stoi(val); + } + catch(...) + { + return default_val; + } + } + + // Get an option value as float + float get_float(const std::string& name, float default_val = 0.0f) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stof(val); + } + catch(...) + { + return default_val; + } + } + + // Get positional arguments + const std::vector& positional() const { return positional_; } + + // Print help message + void print_help() const + { + std::cout << "\n"; + std::cout << " " << name_ << "\n"; + if(!description_.empty()) + { + std::cout << " " << description_ << "\n"; + } + std::cout << "\n"; + std::cout << "Usage:\n"; + std::cout << " ./example [OPTIONS]\n"; + std::cout << "\n"; + std::cout << "Options:\n"; + + // Find max option name length for alignment + size_t max_len = 0; + for(const auto& name : order_) + { + if(name == "-h") + continue; // Skip -h, show --help only + max_len = std::max(max_len, name.length()); + } + + // Print options in order + for(const auto& name : order_) + { + if(name == "-h") + continue; + + std::cout << " " << std::left << std::setw(max_len + 2) << name; + + auto help_it = help_.find(name); + if(help_it != help_.end()) + { + std::cout << help_it->second; + } + + // Show default value for options + auto def_it = defaults_.find(name); + if(def_it != defaults_.end() && !def_it->second.empty()) + { + std::cout << " (default: " << def_it->second << ")"; + } + + std::cout << "\n"; + } + std::cout << "\n"; + } + + private: + std::string name_; + std::string description_; + std::map flags_; + std::map options_; + std::map defaults_; + std::map help_; + std::vector order_; + std::vector positional_; +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/json_export.hpp b/dispatcher/include/ck_tile/dispatcher/json_export.hpp new file mode 100644 index 0000000000..ab1c45412f --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/json_export.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * JSON Export Utilities for Dispatcher Registry + * + * Provides functionality to export kernel registry metadata to JSON format, + * similar to the tile engine benchmarking JSON export. + * + * Features: + * - Export all registered kernels with full metadata + * - Include kernel configuration (tile shapes, pipeline, scheduler, etc.) + * - Group kernels by various properties (data type, layout, pipeline, etc.) + * - Export to string or file + * + * Usage: + * auto& registry = Registry::instance(); + * std::string json = export_registry_json(registry); + * // or + * export_registry_json_to_file(registry, "kernels.json"); + */ + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Convert DataType enum to string +inline std::string datatype_to_string(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return "fp16"; + case DataType::BF16: return "bf16"; + case DataType::FP32: return "fp32"; + case DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert LayoutTag enum to string +inline std::string layout_to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "row_major"; + case LayoutTag::ColMajor: return "col_major"; + case LayoutTag::PackedExternal: return "packed_external"; + default: return "unknown"; + } +} + +/// Convert Pipeline enum to string +inline std::string pipeline_to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return "compv5"; + default: return "unknown"; + } +} + +/// Convert Epilogue enum to string +inline std::string epilogue_to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Default: return "default"; + default: return "unknown"; + } +} + +/// Convert Scheduler enum to string +inline std::string scheduler_to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Escape string for JSON +inline std::string json_escape(const std::string& str) +{ + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); +} + +/// Get current timestamp in ISO 8601 format +inline std::string get_iso_timestamp() +{ + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::tm tm_buf; + localtime_r(&time_t, &tm_buf); + + std::ostringstream oss; + oss << std::put_time(&tm_buf, "%Y-%m-%dT%H:%M:%S"); + return oss.str(); +} + +/// Export a single kernel's metadata to JSON +inline std::string export_kernel_json(const KernelInstance& kernel) +{ + std::ostringstream json; + const auto& key = kernel.get_key(); + + json << " {\n"; + json << " \"name\": \"" << json_escape(kernel.get_name()) << "\",\n"; + json << " \"identifier\": \"" << json_escape(key.encode_identifier()) << "\",\n"; + + // Signature (what operation is computed) + json << " \"signature\": {\n"; + json << " \"dtype_a\": \"" << datatype_to_string(key.signature.dtype_a) << "\",\n"; + json << " \"dtype_b\": \"" << datatype_to_string(key.signature.dtype_b) << "\",\n"; + json << " \"dtype_c\": \"" << datatype_to_string(key.signature.dtype_c) << "\",\n"; + json << " \"dtype_acc\": \"" << datatype_to_string(key.signature.dtype_acc) << "\",\n"; + json << " \"layout_a\": \"" << layout_to_string(key.signature.layout_a) << "\",\n"; + json << " \"layout_b\": \"" << layout_to_string(key.signature.layout_b) << "\",\n"; + json << " \"layout_c\": \"" << layout_to_string(key.signature.layout_c) << "\",\n"; + json << " \"transpose_a\": " << (key.signature.transpose_a ? "true" : "false") << ",\n"; + json << " \"transpose_b\": " << (key.signature.transpose_b ? "true" : "false") << ",\n"; + json << " \"grouped\": " << (key.signature.grouped ? "true" : "false") << ",\n"; + json << " \"split_k\": " << (int)key.signature.split_k << ",\n"; + json << " \"elementwise_op\": \"" << json_escape(key.signature.elementwise_op) + << "\",\n"; + json << " \"num_d_tensors\": " << (int)key.signature.num_d_tensors << ",\n"; + json << " \"structured_sparsity\": " + << (key.signature.structured_sparsity ? "true" : "false") << "\n"; + json << " },\n"; + + // Algorithm (how it's implemented) + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\n"; + json << " \"m\": " << key.algorithm.tile_shape.m << ",\n"; + json << " \"n\": " << key.algorithm.tile_shape.n << ",\n"; + json << " \"k\": " << key.algorithm.tile_shape.k << "\n"; + json << " },\n"; + json << " \"wave_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.wave_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.wave_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.wave_shape.k << "\n"; + json << " },\n"; + json << " \"warp_tile_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.warp_tile_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.warp_tile_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.warp_tile_shape.k << "\n"; + json << " },\n"; + json << " \"pipeline\": \"" << pipeline_to_string(key.algorithm.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << scheduler_to_string(key.algorithm.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << epilogue_to_string(key.algorithm.epilogue) << "\",\n"; + json << " \"block_size\": " << key.algorithm.block_size << ",\n"; + json << " \"double_buffer\": " << (key.algorithm.double_buffer ? "true" : "false") + << ",\n"; + json << " \"persistent\": " << (key.algorithm.persistent ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (key.algorithm.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (key.algorithm.transpose_c ? "true" : "false") << ",\n"; + json << " \"num_wave_groups\": " << (int)key.algorithm.num_wave_groups << "\n"; + json << " },\n"; + + json << " \"gfx_arch\": \"" << json_escape(key.gfx_arch) << "\"\n"; + json << " }"; + + return json.str(); +} + +/// Export registry metadata and statistics to JSON +inline std::string export_registry_json(const Registry& registry, bool include_statistics = true) +{ + std::ostringstream json; + + auto all_kernels = registry.get_all(); + + json << "{\n"; + + // Metadata + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << get_iso_timestamp() << "\",\n"; + json << " \"registry_name\": \"" << json_escape(registry.get_name()) << "\",\n"; + json << " \"total_kernels\": " << all_kernels.size() << ",\n"; + json << " \"export_version\": \"1.0.0\"\n"; + json << " },\n"; + + // Statistics (if enabled) + if(include_statistics && !all_kernels.empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_scheduler; + std::map by_layout; + std::map by_gfx_arch; + + for(const auto& kernel : all_kernels) + { + const auto& key = kernel->get_key(); + + // Count by data type + std::string dtype_key = datatype_to_string(key.signature.dtype_a) + "_" + + datatype_to_string(key.signature.dtype_b) + "_" + + datatype_to_string(key.signature.dtype_c); + by_datatype[dtype_key]++; + + // Count by pipeline + by_pipeline[pipeline_to_string(key.algorithm.pipeline)]++; + + // Count by scheduler + by_scheduler[scheduler_to_string(key.algorithm.scheduler)]++; + + // Count by layout + std::string layout_key = layout_to_string(key.signature.layout_a) + "_" + + layout_to_string(key.signature.layout_b) + "_" + + layout_to_string(key.signature.layout_c); + by_layout[layout_key]++; + + // Count by GFX architecture + by_gfx_arch[key.gfx_arch]++; + } + + json << " \"statistics\": {\n"; + + // Data type breakdown + json << " \"by_datatype\": {\n"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ",\n"; + json << " \"" << dtype << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Pipeline breakdown + json << " \"by_pipeline\": {\n"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ",\n"; + json << " \"" << pipeline << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Scheduler breakdown + json << " \"by_scheduler\": {\n"; + first = true; + for(const auto& [scheduler, count] : by_scheduler) + { + if(!first) + json << ",\n"; + json << " \"" << scheduler << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Layout breakdown + json << " \"by_layout\": {\n"; + first = true; + for(const auto& [layout, count] : by_layout) + { + if(!first) + json << ",\n"; + json << " \"" << layout << "\": " << count; + first = false; + } + json << "\n },\n"; + + // GFX architecture breakdown + json << " \"by_gfx_arch\": {\n"; + first = true; + for(const auto& [arch, count] : by_gfx_arch) + { + if(!first) + json << ",\n"; + json << " \"" << arch << "\": " << count; + first = false; + } + json << "\n }\n"; + + json << " },\n"; + } + + // Kernels list + json << " \"kernels\": [\n"; + for(size_t i = 0; i < all_kernels.size(); ++i) + { + json << export_kernel_json(*all_kernels[i]); + if(i < all_kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + json << " ]\n"; + + json << "}\n"; + + return json.str(); +} + +/// Export registry to a JSON file +inline bool export_registry_json_to_file(const Registry& registry, + const std::string& filename, + bool include_statistics = true) +{ + std::string json = export_registry_json(registry, include_statistics); + + std::ofstream file(filename); + if(!file.is_open()) + { + return false; + } + + file << json; + file.close(); + + return true; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp new file mode 100644 index 0000000000..05011d2c2d --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file kernel_config.hpp + * @brief Explicit kernel configuration for CK Tile Dispatcher + * + * This header provides a KernelConfig struct that mirrors the Python API, + * allowing explicit, self-contained kernel configuration without relying + * on force-included generated headers. + * + * Usage: + * #include "ck_tile/dispatcher/kernel_config.hpp" + * using namespace ck_tile::dispatcher; + * + * // Step 1: Define explicit config + * auto config = KernelConfig::fp16_rcr() + * .tile(128, 128, 32) + * .wave(2, 2, 1) + * .warp_tile(32, 32, 16) + * .pipeline(Pipeline::CompV4) + * .scheduler(Scheduler::Intrawave); + * + * // Step 2: Create registry and register + * Registry registry; + * registry.register_kernel(config.build_key(), config.get_name()); + * + * // Step 3: Create dispatcher + * Dispatcher dispatcher(®istry); + * + * // Step 4: Run GEMM + * dispatcher.run(a, b, c, Problem(M, N, K)); + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Explicit kernel configuration matching Python's KernelConfig + * + * This provides a fluent builder API for creating kernel configurations + * with all parameters visible and explicit. + */ +class KernelConfig +{ + public: + // ========================================================================= + // Data types + // ========================================================================= + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // ========================================================================= + // Layouts + // ========================================================================= + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // ========================================================================= + // Tile shape + // ========================================================================= + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // ========================================================================= + // Wave shape (warps per block) + // ========================================================================= + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // ========================================================================= + // Warp tile shape + // ========================================================================= + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // ========================================================================= + // Block and pipeline + // ========================================================================= + int block_size = 256; + Pipeline pipeline_type = Pipeline::CompV4; + Scheduler scheduler_type = Scheduler::Intrawave; + Epilogue epilogue_type = Epilogue::CShuffle; + + // ========================================================================= + // Padding and features + // ========================================================================= + bool pad_m = true; + bool pad_n = true; + bool pad_k = true; + bool preshuffle = false; + + // ========================================================================= + // Target architecture + // ========================================================================= + std::string gfx_arch = "gfx942"; + + // ========================================================================= + // Fluent builder methods + // ========================================================================= + + /// Set tile dimensions (M x N x K) + KernelConfig& tile(int m, int n, int k) + { + tile_m = m; + tile_n = n; + tile_k = k; + return *this; + } + + /// Set wave dimensions (warps per block M x N x K) + KernelConfig& wave(int m, int n, int k) + { + wave_m = m; + wave_n = n; + wave_k = k; + return *this; + } + + /// Set warp tile dimensions (M x N x K) + KernelConfig& warp_tile(int m, int n, int k) + { + warp_m = m; + warp_n = n; + warp_k = k; + return *this; + } + + /// Set block size + KernelConfig& block(int size) + { + block_size = size; + return *this; + } + + /// Set pipeline type + KernelConfig& pipeline(Pipeline p) + { + pipeline_type = p; + return *this; + } + + /// Set scheduler type + KernelConfig& scheduler(Scheduler s) + { + scheduler_type = s; + return *this; + } + + /// Set epilogue type + KernelConfig& epilogue(Epilogue e) + { + epilogue_type = e; + return *this; + } + + /// Set data types for A, B, C + KernelConfig& dtypes(DataType a, DataType b, DataType c, DataType acc = DataType::FP32) + { + dtype_a = a; + dtype_b = b; + dtype_c = c; + dtype_acc = acc; + return *this; + } + + /// Set layouts for A, B, C + KernelConfig& layouts(LayoutTag a, LayoutTag b, LayoutTag c) + { + layout_a = a; + layout_b = b; + layout_c = c; + return *this; + } + + /// Set padding flags + KernelConfig& padding(bool m, bool n, bool k) + { + pad_m = m; + pad_n = n; + pad_k = k; + return *this; + } + + /// Set target GPU architecture + KernelConfig& arch(const std::string& gpu) + { + gfx_arch = gpu; + return *this; + } + + // ========================================================================= + // Preset configurations + // ========================================================================= + + /// FP16 Row-Column-Row layout (most common) + static KernelConfig fp16_rcr() { return KernelConfig{}; } + + /// FP16 Row-Row-Row layout + static KernelConfig fp16_rrr() + { + KernelConfig cfg; + cfg.layout_b = LayoutTag::RowMajor; + return cfg; + } + + /// BF16 Row-Column-Row layout + static KernelConfig bf16_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::BF16; + cfg.dtype_b = DataType::BF16; + cfg.dtype_c = DataType::BF16; + return cfg; + } + + /// FP32 Row-Column-Row layout + static KernelConfig fp32_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::FP32; + cfg.dtype_b = DataType::FP32; + cfg.dtype_c = DataType::FP32; + cfg.dtype_acc = DataType::FP32; + return cfg; + } + + // ========================================================================= + // Build KernelKey + // ========================================================================= + + /// Build a KernelKey from this configuration + [[nodiscard]] KernelKey build_key() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline_type; + key.algorithm.scheduler = scheduler_type; + key.algorithm.epilogue = epilogue_type; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // ========================================================================= + // String representations + // ========================================================================= + + /// Get tile string (e.g., "128x128x32") + [[nodiscard]] std::string tile_str() const + { + std::ostringstream oss; + oss << tile_m << "x" << tile_n << "x" << tile_k; + return oss.str(); + } + + /// Get wave string (e.g., "2x2x1") + [[nodiscard]] std::string wave_str() const + { + std::ostringstream oss; + oss << wave_m << "x" << wave_n << "x" << wave_k; + return oss.str(); + } + + /// Get warp tile string (e.g., "32x32x16") + [[nodiscard]] std::string warp_tile_str() const + { + std::ostringstream oss; + oss << warp_m << "x" << warp_n << "x" << warp_k; + return oss.str(); + } + + /// Get layout string (e.g., "rcr") + [[nodiscard]] std::string layout_str() const + { + std::ostringstream oss; + oss << to_string(layout_a) << to_string(layout_b) << to_string(layout_c); + return oss.str(); + } + + /// Get kernel name for generated code lookup + [[nodiscard]] std::string get_name() const + { + std::ostringstream oss; + oss << "gemm_" << to_string(dtype_a) << "_" << layout_str() << "_" + << to_string(pipeline_type) << "_" << to_string(epilogue_type) << "_" + << to_string(scheduler_type) << "_" << (pad_m ? "True" : "False") << "_" + << (pad_n ? "True" : "False") << "_" << (pad_k ? "True" : "False") << "_" + << "False" // preshuffle + << "_" << tile_str() << "_" << wave_str() << "_" << warp_tile_str(); + return oss.str(); + } + + /// Print configuration to stdout + void print_config(std::ostream& os = std::cout) const + { + os << " Data types:\n"; + os << " dtype_a = " << to_string(dtype_a) << "\n"; + os << " dtype_b = " << to_string(dtype_b) << "\n"; + os << " dtype_c = " << to_string(dtype_c) << "\n"; + os << " dtype_acc = " << to_string(dtype_acc) << "\n"; + os << " Layouts:\n"; + os << " layout_a = " << to_string(layout_a) << "\n"; + os << " layout_b = " << to_string(layout_b) << "\n"; + os << " layout_c = " << to_string(layout_c) << "\n"; + os << " Tile shape:\n"; + os << " tile = " << tile_str() << "\n"; + os << " wave = " << wave_str() << "\n"; + os << " warp_tile = " << warp_tile_str() << "\n"; + os << " Pipeline:\n"; + os << " pipeline = " << to_string(pipeline_type) << "\n"; + os << " scheduler = " << to_string(scheduler_type) << "\n"; + os << " epilogue = " << to_string(epilogue_type) << "\n"; + os << " Padding:\n"; + os << " pad_m = " << (pad_m ? "true" : "false") << "\n"; + os << " pad_n = " << (pad_n ? "true" : "false") << "\n"; + os << " pad_k = " << (pad_k ? "true" : "false") << "\n"; + os << " Target:\n"; + os << " gfx_arch = " << gfx_arch << "\n"; + } +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp new file mode 100644 index 0000000000..095de52e06 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp @@ -0,0 +1,509 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file kernel_decl.hpp + * @brief Declarative kernel specification with KernelSet + * + * USAGE: + * ====== + * + * // Named kernel sets + * DECL_KERNEL_SET(compute_bound, + * .add("fp16", "rcr", 256, 256, 64) + * .add("fp16", "rcr", 128, 128, 32) + * ); + * + * // Access at runtime + * auto& set = KernelSetRegistry::instance().get("compute_bound"); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace decl { + +// ============================================================================= +// Wildcard constants +// ============================================================================= + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +// ============================================================================= +// Signature Builder +// ============================================================================= + +class Signature +{ + public: + std::string dtype_a_ = "fp16"; + std::string dtype_b_ = "fp16"; + std::string dtype_c_ = "fp16"; + std::string dtype_acc_ = "fp32"; + std::string layout_a_ = "row"; + std::string layout_b_ = "col"; + std::string layout_c_ = "row"; + std::string elementwise_op_ = "PassThrough"; + int num_d_tensors_ = 0; + bool structured_sparsity_ = false; + + Signature& dtype(const std::string& a, + const std::string& b, + const std::string& c, + const std::string& acc = "fp32") + { + dtype_a_ = a; + dtype_b_ = b; + dtype_c_ = c; + dtype_acc_ = acc; + return *this; + } + + Signature& dtype(const std::string& all) + { + dtype_a_ = dtype_b_ = dtype_c_ = all; + dtype_acc_ = "fp32"; + return *this; + } + + Signature& layout(const std::string& a, const std::string& b, const std::string& c) + { + layout_a_ = a; + layout_b_ = b; + layout_c_ = c; + return *this; + } + + Signature& layout(const std::string& combined) + { + if(combined.size() >= 3) + { + layout_a_ = (combined[0] == 'r') ? "row" : "col"; + layout_b_ = (combined[1] == 'r') ? "row" : "col"; + layout_c_ = (combined[2] == 'r') ? "row" : "col"; + } + return *this; + } + + Signature& elementwise(const std::string& op, int num_d = 0) + { + elementwise_op_ = op; + num_d_tensors_ = num_d; + return *this; + } + + std::string layout_str() const + { + std::string r; + r += (layout_a_ == "col") ? 'c' : 'r'; + r += (layout_b_ == "col") ? 'c' : 'r'; + r += (layout_c_ == "col") ? 'c' : 'r'; + return r; + } +}; + +// ============================================================================= +// Algorithm Builder +// ============================================================================= + +class Algorithm +{ + public: + int tile_m_ = 128, tile_n_ = 128, tile_k_ = 32; + int wave_m_ = ANY_INT, wave_n_ = ANY_INT, wave_k_ = 1; + int warp_m_ = ANY_INT, warp_n_ = ANY_INT, warp_k_ = 16; + std::string pipeline_ = "compv4"; + std::string scheduler_ = "intrawave"; + std::string epilogue_ = "cshuffle"; + int block_size_ = 256; + int pad_m_ = 1, pad_n_ = 1, pad_k_ = 1; + bool preshuffle_ = false; + + Algorithm& tile(int m, int n, int k) + { + tile_m_ = m; + tile_n_ = n; + tile_k_ = k; + return *this; + } + + Algorithm& wave(int m, int n, int k = 1) + { + wave_m_ = m; + wave_n_ = n; + wave_k_ = k; + return *this; + } + + Algorithm& warp(int m, int n, int k = 16) + { + warp_m_ = m; + warp_n_ = n; + warp_k_ = k; + return *this; + } + + Algorithm& pipeline(const std::string& p) + { + pipeline_ = p; + return *this; + } + Algorithm& scheduler(const std::string& s) + { + scheduler_ = s; + return *this; + } + Algorithm& epilogue(const std::string& e) + { + epilogue_ = e; + return *this; + } + + Algorithm& pad(bool m, bool n, bool k) + { + pad_m_ = m ? 1 : 0; + pad_n_ = n ? 1 : 0; + pad_k_ = k ? 1 : 0; + return *this; + } + + Algorithm& preshuffle(bool v) + { + preshuffle_ = v; + return *this; + } + + bool needs_expansion() const + { + return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" || pad_m_ == ANY_INT; + } + + void auto_fill() + { + if(wave_m_ == ANY_INT) + wave_m_ = 2; + if(wave_n_ == ANY_INT) + wave_n_ = 2; + if(wave_k_ == ANY_INT) + wave_k_ = 1; + if(warp_m_ == ANY_INT) + warp_m_ = 32; + if(warp_n_ == ANY_INT) + warp_n_ = 32; + if(warp_k_ == ANY_INT) + warp_k_ = 16; + } +}; + +// ============================================================================= +// Kernel Declaration +// ============================================================================= + +struct KernelDecl +{ + Signature signature; + Algorithm algorithm; + std::string arch = "gfx942"; + + KernelDecl() = default; + + KernelDecl(const Signature& sig, const Algorithm& algo, const std::string& a = "gfx942") + : signature(sig), algorithm(algo), arch(a) + { + } + + std::string name() const + { + std::ostringstream oss; + oss << signature.dtype_a_ << "_" << signature.layout_str(); + if(algorithm.tile_m_ > 0) + { + oss << "_" << algorithm.tile_m_ << "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_; + } + return oss.str(); + } + + bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; } +}; + +// ============================================================================= +// KernelSet - Collection of declarations +// ============================================================================= + +class KernelSet +{ + public: + KernelSet() = default; + + KernelSet& add(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + KernelSet& add(const std::string& dtype, + const std::string& layout, + int tm, + int tn, + int tk, + const std::string& arch = "gfx942") + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(tm, tn, tk); + decls_.emplace_back(sig, algo, arch); + return *this; + } + + KernelSet& add(const KernelDecl& decl) + { + decls_.push_back(decl); + return *this; + } + + KernelSet& merge(const KernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + size_t size() const { return decls_.size(); } + + bool needs_expansion() const + { + for(const auto& d : decls_) + { + if(d.algorithm.needs_expansion()) + return true; + } + return false; + } + + void print(std::ostream& os = std::cout) const + { + os << "KernelSet (" << size() << " declarations):\n"; + for(const auto& d : decls_) + { + os << " - " << d.name(); + if(d.algorithm.needs_expansion()) + os << " [expands]"; + os << "\n"; + } + } + + KernelSet& tag(const std::string& t) + { + tag_ = t; + return *this; + } + std::string tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +// ============================================================================= +// KernelSet Registry +// ============================================================================= + +class KernelSetRegistry +{ + public: + static KernelSetRegistry& instance() + { + static KernelSetRegistry reg; + return reg; + } + + void add(const std::string& name, const KernelSet& set) + { + sets_[name] = set; + order_.push_back(name); + } + + const KernelSet& get(const std::string& name) const + { + static KernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + // Return const reference to avoid deep copy + const std::vector& names() const { return order_; } + size_t size() const { return sets_.size(); } + + void print() const + { + std::cout << "Named Kernel Sets (" << size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + KernelSetRegistry() = default; + std::unordered_map sets_; + std::vector order_; +}; + +// ============================================================================= +// Declaration Registry (for DECL_KERNEL) +// ============================================================================= + +class Registry +{ + public: + static Registry& instance() + { + static Registry reg; + return reg; + } + + void add(const KernelDecl& decl) + { + std::string key = decl.has_wildcards() + ? ("wildcard_" + std::to_string(declarations_.size())) + : decl.name(); + declarations_[key] = decl; + order_.push_back(key); + } + + std::vector all() const + { + std::vector result; + for(const auto& key : order_) + { + result.push_back(declarations_.at(key)); + } + return result; + } + + size_t size() const { return declarations_.size(); } + + void print() const + { + std::cout << "Declared kernels (" << size() << "):\n"; + for(const auto& key : order_) + { + const auto& d = declarations_.at(key); + std::cout << " " << d.name(); + if(d.has_wildcards()) + std::cout << " [wildcards]"; + std::cout << "\n"; + } + } + + private: + Registry() = default; + std::unordered_map declarations_; + std::vector order_; +}; + +// ============================================================================= +// Static Registrars +// ============================================================================= + +struct Declarator +{ + Declarator(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942") + { + Registry::instance().add(KernelDecl(sig, algo, arch)); + } + + Declarator(const std::string& dtype, + const std::string& layout, + int tm, + int tn, + int tk, + const std::string& arch = "gfx942") + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(tm, tn, tk); + Registry::instance().add(KernelDecl(sig, algo, arch)); + } + + Declarator(const std::string& dtype, const std::string& layout, const std::string& arch) + { + Signature sig; + sig.dtype(dtype).layout(layout); + Algorithm algo; + algo.tile(ANY_INT, ANY_INT, ANY_INT); + Registry::instance().add(KernelDecl(sig, algo, arch)); + } +}; + +struct KernelSetRegistrar +{ + KernelSetRegistrar(const std::string& name, const KernelSet& set) + { + KernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace decl + +// ============================================================================= +// Convenience Aliases +// ============================================================================= + +using KernelSignature = decl::Signature; +using KernelAlgorithm = decl::Algorithm; +using KernelDecl = decl::KernelDecl; +using KernelDeclRegistry = decl::Registry; +using KernelSet = decl::KernelSet; +using KernelSetRegistry = decl::KernelSetRegistry; + +constexpr const char* ANY = decl::ANY; +constexpr int ANY_INT = decl::ANY_INT; + +} // namespace dispatcher +} // namespace ck_tile + +// ============================================================================= +// Declaration Macros +// ============================================================================= + +#define CK_DECL_CAT_(a, b) CK_DECL_CAT_IMPL_(a, b) +#define CK_DECL_CAT_IMPL_(a, b) a##b + +// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension +#define DECL_KERNEL(sig, algo, ...) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(sig, algo, ##__VA_ARGS__) + +#define DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(#dtype, #layout, tm, tn, tk) + +#define DECL_KERNEL_ALL(dtype, layout) \ + __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_( \ + _kdecl_, __COUNTER__)(#dtype, #layout, "*") + +#define DECL_KERNEL_SET(name, ...) \ + __extension__ static ::ck_tile::dispatcher::decl::KernelSetRegistrar CK_DECL_CAT_( \ + _kset_reg_, __COUNTER__)(#name, \ + ::ck_tile::dispatcher::decl::KernelSet() __VA_ARGS__.tag(#name)) + +#define KERNEL_SET(name) ::ck_tile::dispatcher::decl::KernelSet name +#define BEGIN_KERNEL_SET() ::ck_tile::dispatcher::decl::KernelSet() + +// Legacy compatibility +// Legacy aliases removed - use DECL_KERNEL_SET instead diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp new file mode 100644 index 0000000000..4a734f4c3f --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp @@ -0,0 +1,68 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// KernelInstance: Uniform interface for kernel execution +/// Abstracts away implementation details (CK Library vs CK Tile vs future JIT) +/// Enables type-erased storage in registry while backends perform type-safe casts +class KernelInstance +{ + public: + virtual ~KernelInstance() = default; + + /// Get the kernel's configuration metadata + [[nodiscard]] virtual const KernelKey& get_key() const = 0; + + /// Check if this kernel supports the given problem + /// Returns false if problem dimensions don't meet kernel requirements + /// (e.g., divisibility constraints, resource limits) + [[nodiscard]] virtual bool supports(const Problem& problem) const = 0; + + /// Get human-readable kernel name for logging and debugging + [[nodiscard]] virtual std::string get_name() const = 0; + + /// Execute the kernel with given problem and data pointers + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors for fusion (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds (0 if timing not available) + [[nodiscard]] virtual float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const = 0; + + /// Validate kernel output against reference implementation + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param tolerance Relative error tolerance for validation + /// @return true if validation passes, false otherwise + [[nodiscard]] virtual bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance = 1e-3f) const = 0; +}; + +/// Shared pointer type for kernel instances +using KernelInstancePtr = std::shared_ptr; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp new file mode 100644 index 0000000000..f49b3a0d74 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp @@ -0,0 +1,428 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Data types supported by CK Tile GEMM kernels +/// Matches tile_engine DATA_TYPE_MAP for full compatibility +enum class DataType : std::uint8_t +{ + FP16, // ck_tile::half_t + BF16, // ck_tile::bf16_t + FP32, // float + FP64, // double + FP8, // ck_tile::fp8_t (E4M3) + BF8, // ck_tile::bf8_t (E5M2) + INT8, // ck_tile::int8_t + INT4, // ck_tile::pk_int4_t (packed int4) + INT32, // ck_tile::int32_t + UNKNOWN +}; + +/// Memory layout tags for tensors +enum class LayoutTag : std::uint8_t +{ + RowMajor, + ColMajor, + PackedExternal +}; + +/// Pipeline variants for memory/compute optimization +/// Matches tile_engine PIPELINE_MAP for full compatibility +enum class Pipeline : std::uint8_t +{ + Mem, // Memory-bound pipeline + CompV1, // Compute pipeline v1 + CompV2, // Compute pipeline v2 + CompV3, // Compute pipeline v3 + CompV4, // Compute pipeline v4 (double buffering) + CompV5, // Compute pipeline v5 + PreShuffleV1, // Weight preshuffle pipeline v1 + PreShuffleV2 // Weight preshuffle pipeline v2 (optimized) +}; + +/// Epilogue strategies for output processing +/// Matches tile_engine epilogue options for full compatibility +enum class Epilogue : std::uint8_t +{ + None, + Default, // DefaultGemm2DEpilogue + CShuffle, // CShuffleEpilogue (cross-shuffle) + Bias, // Bias addition + Activation, // Fused activation + BiasActivation // Fused bias + activation +}; + +/// Scheduler types for wave coordination +enum class Scheduler : std::uint8_t +{ + Auto, + Intrawave, + Interwave +}; + +/// KernelKey: Compile-time kernel configuration metadata +/// Organized into Signature (what operation) and Algorithm (how it's implemented) +struct KernelKey +{ + /// Signature: Describes WHAT operation is computed (mathematical semantics) + /// Two kernels with different signatures compute different mathematical operations + struct Signature + { + DataType dtype_a; + DataType dtype_b; + DataType dtype_c; + DataType dtype_acc; + LayoutTag layout_a; + LayoutTag layout_b; + LayoutTag layout_c; + bool transpose_a; + bool transpose_b; + bool grouped; + std::uint8_t split_k; + + // Element-wise fusion: Describes mathematical operation applied to GEMM output + // Examples: PassThrough (C = A*B), MultiDAdd (E = C + D0 + D1), + // MultiDMultiply (E = C * D0 * D1), Clamp, Relu, Gelu, etc. + // This affects the mathematical result, so it belongs in Signature + std::string elementwise_op; // e.g., "PassThrough", "MultiDAdd", "Relu" + std::uint8_t + num_d_tensors; // Number of additional input tensors for fusion (0 for basic GEMM) + + bool structured_sparsity; // 2:4 sparsity affects mathematical correctness + } signature; + + /// Algorithm: Describes HOW it's implemented (performance tuning parameters) + /// Two kernels with same signature but different algorithms compute the same result + /// with different performance characteristics + struct Algorithm + { + // Hierarchical tiling configuration (primary tuning knobs) + struct TileShape + { + std::uint16_t m; + std::uint16_t n; + std::uint16_t k; + } tile_shape; + + struct WaveShape + { + std::uint8_t m; // WarpPerBlock_M in generated kernels + std::uint8_t n; // WarpPerBlock_N + std::uint8_t k; // WarpPerBlock_K + } wave_shape; + + struct WarpTileShape + { + std::uint8_t m; // WarpTileM in generated kernels + std::uint8_t n; // WarpTileN + std::uint8_t k; // WarpTileK + } warp_tile_shape; + + // Pipeline and scheduling strategy + Pipeline pipeline; + Scheduler scheduler; + Epilogue epilogue; + + // Block and memory configuration + std::uint16_t block_size; // BlockSize in generated kernels (typically 256) + bool double_buffer; // DoubleSmemBuffer (true for compv4) + bool persistent; // UsePersistentKernel + bool preshuffle; // Preshuffle (for weight preshuffle variants) + bool transpose_c; // TransposeC + std::uint8_t num_wave_groups; // NumWaveGroups + } algorithm; + + std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908" + + /// Generate a unique string identifier for this kernel configuration + /// Format matches tile_engine naming convention for registry lookup + /// Note: Defined after to_string() functions to use them + [[nodiscard]] std::string encode_identifier() const; + + /// Create a tuple of all fields for comparison operators + auto tie() const + { + return std::tie(signature.dtype_a, + signature.dtype_b, + signature.dtype_c, + signature.dtype_acc, + signature.layout_a, + signature.layout_b, + signature.layout_c, + signature.transpose_a, + signature.transpose_b, + signature.grouped, + signature.split_k, + signature.elementwise_op, + signature.num_d_tensors, + signature.structured_sparsity, + algorithm.tile_shape.m, + algorithm.tile_shape.n, + algorithm.tile_shape.k, + algorithm.wave_shape.m, + algorithm.wave_shape.n, + algorithm.wave_shape.k, + algorithm.warp_tile_shape.m, + algorithm.warp_tile_shape.n, + algorithm.warp_tile_shape.k, + algorithm.pipeline, + algorithm.epilogue, + algorithm.scheduler, + algorithm.block_size, + gfx_arch, + signature.structured_sparsity, + algorithm.persistent, + algorithm.double_buffer, + algorithm.preshuffle, + algorithm.transpose_c, + algorithm.num_wave_groups); + } + + /// Equality comparison + friend bool operator==(const KernelKey& lhs, const KernelKey& rhs) + { + return lhs.tie() == rhs.tie(); + } + + /// Inequality comparison + friend bool operator!=(const KernelKey& lhs, const KernelKey& rhs) { return !(lhs == rhs); } +}; + +// ============================================================================= +// String Conversion Helpers (for serialization and debugging) +// ============================================================================= + +/// Convert DataType to string +inline std::string to_string(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return "fp16"; + case DataType::BF16: return "bf16"; + case DataType::FP32: return "fp32"; + case DataType::FP64: return "fp64"; + case DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT4: return "int4"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert string to DataType +inline DataType string_to_dtype(const std::string& str) +{ + if(str == "fp16") + return DataType::FP16; + if(str == "bf16") + return DataType::BF16; + if(str == "fp32") + return DataType::FP32; + if(str == "fp64") + return DataType::FP64; + if(str == "fp8") + return DataType::FP8; + if(str == "bf8") + return DataType::BF8; + if(str == "int8") + return DataType::INT8; + if(str == "int4") + return DataType::INT4; + if(str == "int32") + return DataType::INT32; + return DataType::UNKNOWN; +} + +/// Convert LayoutTag to string +inline std::string to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "r"; + case LayoutTag::ColMajor: return "c"; + case LayoutTag::PackedExternal: return "p"; + default: return "?"; + } +} + +/// Convert string to LayoutTag +inline LayoutTag string_to_layout(const std::string& str) +{ + if(str == "r" || str == "row" || str == "RowMajor") + return LayoutTag::RowMajor; + if(str == "c" || str == "col" || str == "ColMajor") + return LayoutTag::ColMajor; + if(str == "p" || str == "packed") + return LayoutTag::PackedExternal; + return LayoutTag::RowMajor; // Default +} + +/// Convert Pipeline to string +inline std::string to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return "compv5"; + case Pipeline::PreShuffleV1: return "preshufflev1"; + case Pipeline::PreShuffleV2: return "preshufflev2"; + default: return "unknown"; + } +} + +/// Convert string to Pipeline +inline Pipeline string_to_pipeline(const std::string& str) +{ + if(str == "mem") + return Pipeline::Mem; + if(str == "compv1") + return Pipeline::CompV1; + if(str == "compv2") + return Pipeline::CompV2; + if(str == "compv3") + return Pipeline::CompV3; + if(str == "compv4") + return Pipeline::CompV4; + if(str == "compv5") + return Pipeline::CompV5; + if(str == "preshufflev1") + return Pipeline::PreShuffleV1; + if(str == "preshufflev2") + return Pipeline::PreShuffleV2; + return Pipeline::Mem; // Default +} + +/// Convert Epilogue to string +inline std::string to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Default: return "default"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::BiasActivation: return "bias_activation"; + default: return "unknown"; + } +} + +/// Convert string to Epilogue +inline Epilogue string_to_epilogue(const std::string& str) +{ + if(str == "none") + return Epilogue::None; + if(str == "default") + return Epilogue::Default; + if(str == "cshuffle") + return Epilogue::CShuffle; + if(str == "bias") + return Epilogue::Bias; + if(str == "activation") + return Epilogue::Activation; + if(str == "bias_activation") + return Epilogue::BiasActivation; + return Epilogue::Default; // Default +} + +/// Convert Scheduler to string +inline std::string to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Convert string to Scheduler +inline Scheduler string_to_scheduler(const std::string& str) +{ + if(str == "auto") + return Scheduler::Auto; + if(str == "intrawave") + return Scheduler::Intrawave; + if(str == "interwave") + return Scheduler::Interwave; + return Scheduler::Intrawave; // Default +} + +/// Common elementwise operations (for reference in elementwise_op field) +/// These match CK Tile's ck_tile::element_wise namespace +namespace ElementwiseOps { +constexpr const char* PassThrough = "PassThrough"; +constexpr const char* Add = "Add"; +constexpr const char* Multiply = "Multiply"; +constexpr const char* MultiDAdd = "MultiDAdd"; +constexpr const char* MultiDMultiply = "MultiDMultiply"; +constexpr const char* Relu = "Relu"; +constexpr const char* Gelu = "Gelu"; +constexpr const char* Clamp = "Clamp"; +constexpr const char* Sigmoid = "Sigmoid"; +constexpr const char* Tanh = "Tanh"; +constexpr const char* Swish = "Swish"; +constexpr const char* HardSwish = "HardSwish"; +} // namespace ElementwiseOps + +// ============================================================================= +// KernelKey::encode_identifier() implementation +// Defined after to_string() functions to use them +// ============================================================================= + +inline std::string KernelKey::encode_identifier() const +{ + std::ostringstream oss; + + // Include data types and layout for uniqueness across different signatures + oss << to_string(signature.dtype_a) << "_"; + oss << to_string(signature.layout_a) << to_string(signature.layout_b) + << to_string(signature.layout_c) << "_"; + + // Include pipeline, scheduler, epilogue for uniqueness + oss << to_string(algorithm.pipeline) << "_"; + oss << to_string(algorithm.scheduler) << "_"; + oss << to_string(algorithm.epilogue) << "_"; + + // Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _ + // warp_tile_m x warp_tile_n x warp_tile_k + oss << algorithm.tile_shape.m << "x" << algorithm.tile_shape.n << "x" << algorithm.tile_shape.k + << "_" << unsigned(algorithm.wave_shape.m) << "x" << unsigned(algorithm.wave_shape.n) << "x" + << unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x" + << unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k); + + // Add trait flags + oss << "_" << (algorithm.persistent ? "persist" : "nopers"); + + if(signature.split_k > 1) + oss << "_splitk" << unsigned(signature.split_k); + if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough") + oss << "_" << signature.elementwise_op; + if(signature.num_d_tensors > 0) + oss << "_d" << unsigned(signature.num_d_tensors); + if(signature.structured_sparsity) + oss << "_sparse"; + if(algorithm.preshuffle) + oss << "_preshuffle"; + + return oss.str(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp new file mode 100644 index 0000000000..437511d1ba --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -0,0 +1,311 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Tensor Information for Automatic MNK Inference +// ============================================================================= + +/// TensorShape: Describes tensor dimensions for automatic MNK inference +struct TensorShape +{ + std::int64_t rows; // First dimension + std::int64_t cols; // Second dimension + bool is_transposed; // Whether the tensor is transposed (column-major) + + TensorShape() : rows(0), cols(0), is_transposed(false) {} + TensorShape(std::int64_t r, std::int64_t c, bool trans = false) + : rows(r), cols(c), is_transposed(trans) + { + } + + /// Get logical M (rows when not transposed) + [[nodiscard]] std::int64_t logical_rows() const { return is_transposed ? cols : rows; } + + /// Get logical N (cols when not transposed) + [[nodiscard]] std::int64_t logical_cols() const { return is_transposed ? rows : cols; } +}; + +// ============================================================================= +// Problem: Runtime Parameters +// ============================================================================= + +/// Problem: Runtime parameters for kernel invocation +/// Captures problem dimensions and resource constraints that vary between invocations +/// even when using the same kernel +struct Problem +{ + // Problem dimensions + std::int64_t M; // Number of rows in A and C + std::int64_t N; // Number of columns in B and C + std::int64_t K; // Shared dimension (columns of A, rows of B) + + // Batch configuration + std::int32_t k_batch; // Number of K-dimension splits for split-K GEMM + + // Resource preferences + std::int32_t smem_budget; // Shared memory budget in bytes (0 = no constraint) + bool prefer_persistent; // Prefer persistent kernel variants + + // Validation control + bool enable_validation; // Enable output validation against reference + + /// Default constructor with sensible defaults + Problem() + : M(0), + N(0), + K(0), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Constructor with problem dimensions + Problem(std::int64_t m, std::int64_t n, std::int64_t k) + : M(m), + N(n), + K(k), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Check if problem dimensions are valid + [[nodiscard]] bool is_valid() const { return M > 0 && N > 0 && K > 0 && k_batch > 0; } + + /// Get total number of operations (for performance metrics) + [[nodiscard]] std::int64_t num_ops() const + { + return 2 * M * N * K; // Multiply-add counts as 2 ops + } + + // ========================================================================= + // Factory Methods for Automatic MNK Inference + // ========================================================================= + + /** + * Create Problem by inferring MNK from tensor shapes. + * + * For GEMM: C[M,N] = A[M,K] × B[K,N] + * + * @param a_shape Shape of matrix A (M x K, or K x M if transposed) + * @param b_shape Shape of matrix B (K x N, or N x K if transposed) + * @param c_shape Shape of matrix C (M x N) - used for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A is 512x256, B is 256x1024, C is 512x1024 + * auto problem = Problem::from_shapes({512, 256}, {256, 1024}, {512, 1024}); + * // Infers: M=512, N=1024, K=256 + */ + [[nodiscard]] static Problem + from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) + { + // For C = A × B: + // A: [M, K] (or [K, M] if transposed) + // B: [K, N] (or [N, K] if transposed) + // C: [M, N] + + std::int64_t M_from_A = a_shape.logical_rows(); + std::int64_t K_from_A = a_shape.logical_cols(); + std::int64_t K_from_B = b_shape.logical_rows(); + std::int64_t N_from_B = b_shape.logical_cols(); + std::int64_t M_from_C = c_shape.logical_rows(); + std::int64_t N_from_C = c_shape.logical_cols(); + + // Validate K dimension matches between A and B + if(K_from_A != K_from_B) + { + throw std::invalid_argument( + "K dimension mismatch: A has K=" + std::to_string(K_from_A) + + ", B has K=" + std::to_string(K_from_B)); + } + + // Validate M dimension matches between A and C + if(M_from_A != M_from_C) + { + throw std::invalid_argument( + "M dimension mismatch: A has M=" + std::to_string(M_from_A) + + ", C has M=" + std::to_string(M_from_C)); + } + + // Validate N dimension matches between B and C + if(N_from_B != N_from_C) + { + throw std::invalid_argument( + "N dimension mismatch: B has N=" + std::to_string(N_from_B) + + ", C has N=" + std::to_string(N_from_C)); + } + + return Problem(M_from_A, N_from_B, K_from_A); + } + + /** + * Create Problem from tensor dimensions (simple version without transpose). + * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) + * @param b_cols Columns of matrix B (= N) + * @param c_rows Rows of matrix C (= M) - for validation + * @param c_cols Columns of matrix C (= N) - for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); + */ + [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, + std::int64_t a_cols, + std::int64_t b_rows, + std::int64_t b_cols, + std::int64_t c_rows, + std::int64_t c_cols) + { + return from_shapes( + TensorShape(a_rows, a_cols), TensorShape(b_rows, b_cols), TensorShape(c_rows, c_cols)); + } + + /** + * Create Problem from A and B dimensions only (C is inferred). + * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) - validated + * @param b_cols Columns of matrix B (= N) + * @throws std::invalid_argument if K dimensions don't match + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_ab(512, 256, 256, 1024); + */ + [[nodiscard]] static Problem + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + if(a_cols != b_rows) + { + throw std::invalid_argument("K dimension mismatch: A.cols=" + std::to_string(a_cols) + + ", B.rows=" + std::to_string(b_rows)); + } + return Problem(a_rows, b_cols, a_cols); + } + + /** + * Validate that tensor pointers have consistent sizes. + * Call this before kernel execution to catch dimension errors early. + * + * @param a_size Total elements in A tensor + * @param b_size Total elements in B tensor + * @param c_size Total elements in C tensor + * @throws std::invalid_argument if sizes don't match expected dimensions + */ + void validate_sizes(std::int64_t a_size, std::int64_t b_size, std::int64_t c_size) const + { + std::int64_t expected_a = M * K; + std::int64_t expected_b = K * N; + std::int64_t expected_c = M * N; + + if(a_size != expected_a) + { + throw std::invalid_argument("A tensor size mismatch: got " + std::to_string(a_size) + + ", expected " + std::to_string(expected_a) + " (M*K = " + + std::to_string(M) + "*" + std::to_string(K) + ")"); + } + if(b_size != expected_b) + { + throw std::invalid_argument("B tensor size mismatch: got " + std::to_string(b_size) + + ", expected " + std::to_string(expected_b) + " (K*N = " + + std::to_string(K) + "*" + std::to_string(N) + ")"); + } + if(c_size != expected_c) + { + throw std::invalid_argument("C tensor size mismatch: got " + std::to_string(c_size) + + ", expected " + std::to_string(expected_c) + " (M*N = " + + std::to_string(M) + "*" + std::to_string(N) + ")"); + } + } +}; + +// ============================================================================= +// Convenience Builders +// ============================================================================= + +/// Builder pattern for Problem configuration +class ProblemBuilder +{ + public: + ProblemBuilder() = default; + + /// Set dimensions from A and B shapes + ProblemBuilder& + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + problem_ = Problem::from_ab(a_rows, a_cols, b_rows, b_cols); + return *this; + } + + /// Set MNK directly + ProblemBuilder& dimensions(std::int64_t m, std::int64_t n, std::int64_t k) + { + problem_.M = m; + problem_.N = n; + problem_.K = k; + return *this; + } + + /// Set split-K batch count + ProblemBuilder& split_k(std::int32_t k_batch) + { + problem_.k_batch = k_batch; + return *this; + } + + /// Set shared memory budget + ProblemBuilder& smem_budget(std::int32_t budget) + { + problem_.smem_budget = budget; + return *this; + } + + /// Prefer persistent kernels + ProblemBuilder& persistent(bool prefer = true) + { + problem_.prefer_persistent = prefer; + return *this; + } + + /// Enable validation + ProblemBuilder& validate(bool enable = true) + { + problem_.enable_validation = enable; + return *this; + } + + /// Build the Problem + [[nodiscard]] Problem build() const + { + if(!problem_.is_valid()) + { + throw std::invalid_argument("Invalid problem dimensions"); + } + return problem_; + } + + private: + Problem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp new file mode 100644 index 0000000000..93d1eb9f64 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -0,0 +1,197 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Registry - Thread-Safe Kernel Storage + * + * Central registry for all available kernel instances with priority-based + * ordering and efficient lookup. + * + * Features: + * - Thread-safe registration and lookup + * - Priority-based ordering (High, Normal, Low) + * - Lookup by name or KernelKey + * - Filter by problem compatibility + * - Supports both singleton and multiple instance patterns + * + * Usage (Singleton - backward compatible): + * auto& registry = Registry::instance(); + * registry.register_kernel(kernel, Priority::High); + * auto kernel = registry.lookup("kernel_name"); + * + * Usage (Multiple registries): + * Registry fp16_registry; + * Registry bf16_registry; + * fp16_registry.register_kernel(fp16_kernel, Priority::High); + * bf16_registry.register_kernel(bf16_kernel, Priority::High); + * + * Dispatcher fp16_dispatcher(&fp16_registry); + * Dispatcher bf16_dispatcher(&bf16_registry); + * + * Status: Production ready, thread-safe + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Registry: Central mapping from kernel configurations to executable instances +/// Thread-safe kernel registration and lookup +/// Supports both singleton pattern and multiple independent instances +class Registry +{ + public: + /// Priority levels for conflict resolution when multiple kernels have same key + enum class Priority + { + Low = 0, + Normal = 1, + High = 2 + }; + + /// Default constructor - creates an empty registry instance + /// Use this to create independent registries for different kernel sets + Registry(); + + /// Destructor - triggers auto-export if enabled + ~Registry(); + + /// Move constructor + Registry(Registry&& other) noexcept; + + /// Move assignment + Registry& operator=(Registry&& other) noexcept; + + // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + Registry(const Registry&) = delete; + Registry& operator=(const Registry&) = delete; + + /// Register a kernel instance with the registry + /// @param instance Kernel instance to register + /// @param priority Priority level for conflict resolution (default: Normal) + /// @return true if registered successfully, false if duplicate with higher priority exists + bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); + + /// Lookup a kernel by its string identifier + /// @param identifier Kernel identifier string + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; + + /// Lookup a kernel by its KernelKey + /// @param key Kernel configuration key + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; + + /// Get all registered kernels + /// @return Vector of all kernel instances + [[nodiscard]] std::vector get_all() const; + + /// Get all kernels matching a predicate + /// @param predicate Function to filter kernels + /// @return Vector of matching kernel instances + [[nodiscard]] std::vector + filter(std::function predicate) const; + + /// Get number of registered kernels + [[nodiscard]] std::size_t size() const; + + /// Check if registry is empty + [[nodiscard]] bool empty() const; + + /// Clear all registered kernels + void clear(); + + /// Get registry name (for logging/debugging) + [[nodiscard]] const std::string& get_name() const; + + /// Set registry name (for logging/debugging) + void set_name(const std::string& name); + + /// Export registry to JSON string + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return JSON string with all kernel metadata + [[nodiscard]] std::string export_json(bool include_statistics = true) const; + + /// Export registry to JSON file + /// @param filename Output filename + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return true if export succeeded, false otherwise + bool export_json_to_file(const std::string& filename, bool include_statistics = true) const; + + /// Enable automatic JSON export on kernel registration + /// @param filename Output filename for auto-export + /// @param include_statistics Whether to include statistics in auto-export + /// @param export_on_every_registration If true, exports after every registration (default). + /// If false, only exports on destruction. + void enable_auto_export(const std::string& filename, + bool include_statistics = true, + bool export_on_every_registration = true); + + /// Disable automatic JSON export + void disable_auto_export(); + + /// Check if auto-export is enabled + [[nodiscard]] bool is_auto_export_enabled() const; + + /// Merge kernels from another registry into this one + /// @param other Registry to merge from + /// @param priority Priority for merged kernels (default: Normal) + /// @return Number of kernels successfully merged + std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); + + /// Filter kernels in-place by architecture + /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") + /// @return Number of kernels removed + std::size_t filter_by_arch(const std::string& gpu_arch); + + /// Get singleton instance of the global registry (backward compatible) + /// This is the default registry used when no specific registry is provided + static Registry& instance(); + + private: + struct RegistryEntry + { + KernelInstancePtr instance; + Priority priority; + }; + + /// Perform auto-export if enabled + void perform_auto_export(); + + mutable std::mutex mutex_; + std::unordered_map kernels_; + std::string name_; + + // Auto-export configuration + bool auto_export_enabled_ = false; + std::string auto_export_filename_; + bool auto_export_include_statistics_ = true; + bool auto_export_on_every_registration_ = true; +}; + +/// Shared pointer type for registries (useful for managing lifetime) +using RegistryPtr = std::shared_ptr; + +/// Create a new registry instance (factory function) +inline RegistryPtr make_registry(const std::string& name = "") +{ + auto reg = std::make_shared(); + if(!name.empty()) + { + reg->set_name(name); + } + return reg; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/utils.hpp b/dispatcher/include/ck_tile/dispatcher/utils.hpp new file mode 100644 index 0000000000..0f9990c45e --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/utils.hpp @@ -0,0 +1,724 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file utils.hpp + * @brief Common utilities for CK Tile Dispatcher + * + * This header provides reusable utilities for: + * - GPU memory management (GpuBuffer) + * - Performance measurement (Timer, GpuTimer, BenchmarkStats) + * - Validation (ValidationResult, validate_result) + * - Kernel registration helpers + * - Data generation (fill_random, etc.) + * + * Usage: + * #include "ck_tile/dispatcher/utils.hpp" + * using namespace ck_tile::dispatcher::utils; + * + * // GPU memory + * GpuBuffer buffer(1024); + * + * // Timing + * GpuTimer timer; + * timer.start(); + * // ... kernel ... + * timer.stop(); + * float ms = timer.elapsed_ms(); + * + * // Validation + * auto result = validate_result(gpu_data, ref_data, size); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +// ============================================================================= +// HIP Error Handling +// ============================================================================= + +#define CK_HIP_CHECK(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << std::endl; \ + return false; \ + } \ + } while(0) + +#define CK_HIP_CHECK_THROW(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + throw std::runtime_error(std::string("HIP error: ") + hipGetErrorString(err)); \ + } \ + } while(0) + +// ============================================================================= +// Timing Utilities +// ============================================================================= + +/** + * @brief High-resolution timer for CPU timing + */ +class Timer +{ + public: + void start() { start_ = std::chrono::high_resolution_clock::now(); } + + double elapsed_ms() const + { + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start_).count(); + } + + private: + std::chrono::high_resolution_clock::time_point start_; +}; + +/** + * @brief GPU timing using HIP events + * + * Times kernel execution on a specific HIP stream. Events are recorded + * on the provided stream to accurately measure kernel execution time. + * + * Usage: + * hipStream_t stream; + * hipStreamCreate(&stream); + * GpuTimer timer(stream); // or timer.set_stream(stream) + * timer.start(); + * kernel<<>>(...); + * timer.stop(); + * float ms = timer.elapsed_ms(); + */ +class GpuTimer +{ + public: + /** + * @brief Construct timer with optional stream + * @param stream HIP stream to record events on (default: null stream) + */ + explicit GpuTimer(hipStream_t stream = nullptr) : stream_(stream) + { + (void)hipEventCreate(&start_); + (void)hipEventCreate(&stop_); + } + + ~GpuTimer() + { + (void)hipEventDestroy(start_); + (void)hipEventDestroy(stop_); + } + + // Non-copyable + GpuTimer(const GpuTimer&) = delete; + GpuTimer& operator=(const GpuTimer&) = delete; + + // Movable + GpuTimer(GpuTimer&& other) noexcept + : start_(other.start_), stop_(other.stop_), stream_(other.stream_) + { + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + + GpuTimer& operator=(GpuTimer&& other) noexcept + { + if(this != &other) + { + if(start_) + (void)hipEventDestroy(start_); + if(stop_) + (void)hipEventDestroy(stop_); + start_ = other.start_; + stop_ = other.stop_; + stream_ = other.stream_; + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + return *this; + } + + /** + * @brief Set the stream to record events on + * @param stream HIP stream (pass nullptr for default stream) + */ + void set_stream(hipStream_t stream) { stream_ = stream; } + + /** + * @brief Get the current stream + */ + hipStream_t get_stream() const { return stream_; } + + /** + * @brief Record start event on the stream + */ + void start() { (void)hipEventRecord(start_, stream_); } + + /** + * @brief Record stop event on the stream + */ + void stop() { (void)hipEventRecord(stop_, stream_); } + + /** + * @brief Get elapsed time in milliseconds + * + * Synchronizes on the stop event before calculating time. + * @return Elapsed time between start and stop in milliseconds + */ + float elapsed_ms() + { + (void)hipEventSynchronize(stop_); + float ms = 0; + (void)hipEventElapsedTime(&ms, start_, stop_); + return ms; + } + + private: + hipEvent_t start_ = nullptr; + hipEvent_t stop_ = nullptr; + hipStream_t stream_ = nullptr; +}; + +// ============================================================================= +// Performance Metrics +// ============================================================================= + +/** + * @brief Calculate TFLOPS for GEMM + */ +inline double calculate_tflops(int64_t M, int64_t N, int64_t K, double time_ms) +{ + double flops = 2.0 * M * N * K; + return (flops / (time_ms * 1e-3)) / 1e12; +} + +/** + * @brief Calculate memory bandwidth in GB/s + */ +template +inline double calculate_bandwidth_gbs(int64_t M, int64_t N, int64_t K, double time_ms) +{ + double bytes = M * K * sizeof(AType) + K * N * sizeof(BType) + M * N * sizeof(CType); + return (bytes / (time_ms * 1e-3)) / 1e9; +} + +/** + * @brief Benchmark statistics + */ +struct BenchmarkStats +{ + double min_ms = 0; + double avg_ms = 0; + double max_ms = 0; + double median_ms = 0; + double tflops = 0; + double bandwidth_gbs = 0; + int iterations = 0; + + void print(std::ostream& os = std::cout) const + { + os << std::fixed << std::setprecision(4); + os << " Min: " << min_ms << " ms\n"; + os << " Avg: " << avg_ms << " ms\n"; + os << " Max: " << max_ms << " ms\n"; + os << " Median: " << median_ms << " ms\n"; + os << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + os << " Bandwidth: " << bandwidth_gbs << " GB/s\n"; + } +}; + +/** + * @brief Run benchmark and compute statistics + */ +template +BenchmarkStats run_benchmark(Func&& func, int warmup = 2, int iterations = 10) +{ + std::vector times; + times.reserve(iterations); + + for(int i = 0; i < warmup; ++i) + func(); + + for(int i = 0; i < iterations; ++i) + times.push_back(func()); + + std::sort(times.begin(), times.end()); + + BenchmarkStats stats; + stats.iterations = iterations; + stats.min_ms = times.front(); + stats.max_ms = times.back(); + stats.median_ms = times[iterations / 2]; + + double sum = 0; + for(double t : times) + sum += t; + stats.avg_ms = sum / iterations; + + return stats; +} + +// ============================================================================= +// Validation Utilities +// ============================================================================= + +/** + * @brief Validation result + */ +struct ValidationResult +{ + bool correct = false; + double max_diff = 0; + double mean_diff = 0; + double accuracy = 0; + int64_t matches = 0; + int64_t total = 0; + + void print(std::ostream& os = std::cout) const + { + os << " Correct: " << (correct ? "YES" : "NO") << "\n"; + os << " Max diff: " << max_diff << "\n"; + os << " Mean diff: " << mean_diff << "\n"; + os << " Accuracy: " << accuracy << "%\n"; + os << " Matches: " << matches << "/" << total << "\n"; + } +}; + +/** + * @brief Validate GEMM result against reference + */ +template +ValidationResult validate_result( + const T* result, const T* reference, int64_t size, double rtol = 1e-3, double atol = 1e-2) +{ + ValidationResult v; + v.total = size; + v.max_diff = 0; + v.matches = 0; + + double sum_diff = 0; + + for(int64_t i = 0; i < size; ++i) + { + double r = static_cast(result[i]); + double ref = static_cast(reference[i]); + double diff = std::abs(r - ref); + + v.max_diff = std::max(v.max_diff, diff); + sum_diff += diff; + + double threshold = atol + rtol * std::abs(ref); + if(diff <= threshold) + ++v.matches; + } + + v.mean_diff = sum_diff / size; + v.accuracy = 100.0 * v.matches / v.total; + v.correct = (v.matches == v.total) || (v.accuracy >= 99.9); + + return v; +} + +/** + * @brief Compute reference GEMM on CPU + */ +template +void compute_reference_gemm( + const AType* A, const BType* B, CType* C, int64_t M, int64_t N, int64_t K) +{ + for(int64_t m = 0; m < M; ++m) + { + for(int64_t n = 0; n < N; ++n) + { + double acc = 0; + for(int64_t k = 0; k < K; ++k) + acc += static_cast(A[m * K + k]) * static_cast(B[k * N + n]); + C[m * N + n] = static_cast(acc); + } + } +} + +// ============================================================================= +// Data Generation +// ============================================================================= + +template +void fill_random(T* data, int64_t size, T min_val = T(-1), T max_val = T(1)) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(static_cast(min_val), + static_cast(max_val)); + for(int64_t i = 0; i < size; ++i) + data[i] = static_cast(dist(gen)); +} + +template +void fill_zeros(T* data, int64_t size) +{ + std::fill(data, data + size, T(0)); +} + +template +void fill_ones(T* data, int64_t size) +{ + std::fill(data, data + size, T(1)); +} + +template +void fill_identity(T* data, int64_t rows, int64_t cols) +{ + fill_zeros(data, rows * cols); + int64_t min_dim = std::min(rows, cols); + for(int64_t i = 0; i < min_dim; ++i) + data[i * cols + i] = T(1); +} + +// ============================================================================= +// GPU Memory Management +// ============================================================================= + +/** + * @brief RAII wrapper for GPU memory + */ +template +class GpuBuffer +{ + public: + GpuBuffer() : data_(nullptr), size_(0) {} + + explicit GpuBuffer(int64_t count) : size_(count * sizeof(T)) + { + CK_HIP_CHECK_THROW(hipMalloc(&data_, size_)); + } + + ~GpuBuffer() + { + if(data_) + (void)hipFree(data_); + } + + // Non-copyable + GpuBuffer(const GpuBuffer&) = delete; + GpuBuffer& operator=(const GpuBuffer&) = delete; + + // Movable + GpuBuffer(GpuBuffer&& other) noexcept : data_(other.data_), size_(other.size_) + { + other.data_ = nullptr; + other.size_ = 0; + } + + GpuBuffer& operator=(GpuBuffer&& other) noexcept + { + if(this != &other) + { + if(data_) + (void)hipFree(data_); + data_ = other.data_; + size_ = other.size_; + other.data_ = nullptr; + other.size_ = 0; + } + return *this; + } + + T* get() { return data_; } + const T* get() const { return data_; } + int64_t size_bytes() const { return size_; } + int64_t count() const { return size_ / sizeof(T); } + + void copy_from_host(const T* host_data) + { + CK_HIP_CHECK_THROW(hipMemcpy(data_, host_data, size_, hipMemcpyHostToDevice)); + } + + void copy_to_host(T* host_data) const + { + CK_HIP_CHECK_THROW(hipMemcpy(host_data, data_, size_, hipMemcpyDeviceToHost)); + } + + void zero() { CK_HIP_CHECK_THROW(hipMemset(data_, 0, size_)); } + + private: + T* data_; + int64_t size_; +}; + +// ============================================================================= +// Printing Utilities +// ============================================================================= + +inline void print_separator(char c = '=', int width = 70) +{ + std::cout << std::string(width, c) << "\n"; +} + +inline void print_header(const std::string& title) +{ + print_separator(); + std::cout << title << "\n"; + print_separator(); +} + +inline std::string format_size(int64_t M, int64_t N, int64_t K) +{ + std::ostringstream oss; + oss << M << "x" << N << "x" << K; + return oss.str(); +} + +inline std::string format_number(int64_t n) +{ + std::string s = std::to_string(n); + int pos = static_cast(s.length()) - 3; + while(pos > 0) + { + s.insert(pos, ","); + pos -= 3; + } + return s; +} + +/** + * @brief Print all registered kernels in a registry + * + * @param registry The registry to list kernels from + * @param os Output stream (default: std::cout) + * @param verbose If true, show full kernel config details + */ +inline void print_registered_kernels(const Registry& registry, + std::ostream& os = std::cout, + bool verbose = false) +{ + const auto& kernels = registry.get_all(); + os << "Registered Kernels (" << kernels.size() << "):\n"; + os << std::string(70, '-') << "\n"; + + int idx = 1; + for(const auto& kernel : kernels) + { + const auto& key = kernel->get_key(); + + os << " " << idx++ << ". " << kernel->get_name() << "\n"; + + if(verbose) + { + os << " Tile: " << key.algorithm.tile_shape.m << "x" + << key.algorithm.tile_shape.n << "x" << key.algorithm.tile_shape.k << "\n"; + os << " Wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " WarpTile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " Pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " Scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " Arch: " << key.gfx_arch << "\n"; + os << "\n"; + } + } + + if(!verbose && !kernels.empty()) + { + os << "\n Use --list-verbose for full details\n"; + } + os << std::string(70, '-') << "\n"; +} + +/** + * @brief Print a single kernel's configuration + */ +inline void print_kernel_info(const KernelInstance& kernel, std::ostream& os = std::cout) +{ + const auto& key = kernel.get_key(); + + os << "Kernel: " << kernel.get_name() << "\n"; + os << " Signature:\n"; + os << " dtype: " << to_string(key.signature.dtype_a) << "/" + << to_string(key.signature.dtype_b) << "/" << to_string(key.signature.dtype_c) << "\n"; + os << " layout: " << to_string(key.signature.layout_a) << to_string(key.signature.layout_b) + << to_string(key.signature.layout_c) << "\n"; + + os << " Algorithm:\n"; + os << " tile: " << key.algorithm.tile_shape.m << "x" << key.algorithm.tile_shape.n + << "x" << key.algorithm.tile_shape.k << "\n"; + os << " wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " warp_tile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " epilogue: " << to_string(key.algorithm.epilogue) << "\n"; + + os << " Target: " << key.gfx_arch << "\n"; +} + +// ============================================================================= +// Kernel Key Builders +// ============================================================================= + +/** + * @brief Build a KernelKey for FP16 Row-Col-Row layout GEMM + * + * This is the most common configuration. Customize parameters as needed. + */ +struct KernelKeyBuilder +{ + // Tile shape + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // Wave shape (warps per block) + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // Warp tile shape + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Block size + int block_size = 256; + + // Data types + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // Layouts + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // Pipeline/scheduler + Pipeline pipeline = Pipeline::CompV4; + Scheduler scheduler = Scheduler::Intrawave; + Epilogue epilogue = Epilogue::CShuffle; + + // Features + bool preshuffle = false; + int num_d_tensors = 0; // Multi-D: number of additional input tensors + std::string elementwise_op = "PassThrough"; + + // Target GPU + std::string gfx_arch = "gfx942"; + + /** + * @brief Build the KernelKey + */ + KernelKey build() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = elementwise_op; + key.signature.num_d_tensors = num_d_tensors; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline; + key.algorithm.scheduler = scheduler; + key.algorithm.epilogue = epilogue; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // Convenience preset methods + static KernelKeyBuilder fp16_rcr() { return KernelKeyBuilder{}; } + + static KernelKeyBuilder fp16_rrr() + { + auto b = KernelKeyBuilder{}; + b.layout_b = LayoutTag::RowMajor; + return b; + } + + static KernelKeyBuilder preshuffle_v1() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV1; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder preshuffle_v2() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV2; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder multi_d(int num_d, const std::string& op = "MultiDAdd") + { + auto b = KernelKeyBuilder{}; + b.num_d_tensors = num_d; + b.elementwise_op = op; + return b; + } +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp new file mode 100644 index 0000000000..a7e063c3cc --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp @@ -0,0 +1,228 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace validation { + +/// Reference CPU GEMM implementation for validation +template +void reference_gemm_cpu(const ADataType* a, + const BDataType* b, + CDataType* c, + int M, + int N, + int K, + int stride_a, + int stride_b, + int stride_c, + bool transpose_a = false, + bool transpose_b = false) +{ + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + // Get A element + int a_idx = transpose_a ? (k * stride_a + m) : (m * stride_a + k); + AccDataType a_val = static_cast(a[a_idx]); + + // Get B element + int b_idx = transpose_b ? (n * stride_b + k) : (k * stride_b + n); + AccDataType b_val = static_cast(b[b_idx]); + + acc += a_val * b_val; + } + + // Write C element + int c_idx = m * stride_c + n; + c[c_idx] = static_cast(acc); + } + } +} + +/// Validate kernel output against reference +template +bool validate_output(const CDataType* result, + const CDataType* reference, + int size, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + int errors = 0; + const int max_errors_to_print = 10; + + for(int i = 0; i < size; ++i) + { + float res_val = static_cast(result[i]); + float ref_val = static_cast(reference[i]); + + float abs_diff = std::abs(res_val - ref_val); + float abs_ref = std::abs(ref_val); + + bool is_valid = (abs_diff <= atol) || (abs_diff <= rtol * abs_ref); + + if(!is_valid) + { + if(errors < max_errors_to_print) + { + printf("Mismatch at index %d: result=%.6f, reference=%.6f, diff=%.6e\n", + i, + res_val, + ref_val, + abs_diff); + } + errors++; + } + } + + if(errors > 0) + { + printf("Validation failed: %d/%d elements mismatched (%.2f%%)\n", + errors, + size, + 100.0f * errors / size); + return false; + } + + return true; +} + +/// Validate kernel with reference implementation +template +bool validate_gemm_kernel(const void* a_dev_ptr, + const void* b_dev_ptr, + const void* c_dev_ptr, + const Problem& problem, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + const int M = problem.M; + const int N = problem.N; + const int K = problem.K; + + // Allocate host memory + std::vector a_host(M * K); + std::vector b_host(K * N); + std::vector c_host(M * N); + std::vector c_ref(M * N); + + // Copy from device + hipMemcpy(a_host.data(), a_dev_ptr, M * K * sizeof(ADataType), hipMemcpyDeviceToHost); + hipMemcpy(b_host.data(), b_dev_ptr, K * N * sizeof(BDataType), hipMemcpyDeviceToHost); + hipMemcpy(c_host.data(), c_dev_ptr, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + // Compute reference + reference_gemm_cpu(a_host.data(), + b_host.data(), + c_ref.data(), + M, + N, + K, + K, // stride_a (row-major) + N, // stride_b (row-major) + N, // stride_c (row-major) + false, + false); + + // Validate + return validate_output(c_host.data(), c_ref.data(), M * N, rtol, atol); +} + +/// Validator class for kernel instances +class KernelValidator +{ + public: + KernelValidator(float rtol = 1e-3f, float atol = 1e-5f) : rtol_(rtol), atol_(atol) {} + + /// Validate a kernel instance + template + bool validate(KernelInstance& kernel, + const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const Problem& problem) + { + // Use kernel's validate method if available + return kernel.validate(a_ptr, b_ptr, c_ptr, problem, rtol_, atol_); + } + + /// Set tolerances + void set_tolerances(float rtol, float atol) + { + rtol_ = rtol; + atol_ = atol; + } + + /// Get tolerances + std::pair get_tolerances() const { return {rtol_, atol_}; } + + private: + float rtol_; + float atol_; +}; + +/// Helper to generate random test data +template +void generate_random_data(T* data, int size, float min_val = -1.0f, float max_val = 1.0f) +{ + for(int i = 0; i < size; ++i) + { + float rand_val = min_val + (max_val - min_val) * (rand() / (float)RAND_MAX); + data[i] = static_cast(rand_val); + } +} + +/// Helper to allocate and initialize test tensors +template +struct TestTensor +{ + T* host_ptr; + T* device_ptr; + int size; + + TestTensor(int size_) : size(size_) + { + host_ptr = new T[size]; + hipMalloc(&device_ptr, size * sizeof(T)); + } + + ~TestTensor() + { + delete[] host_ptr; + hipFree(device_ptr); + } + + void randomize(float min_val = -1.0f, float max_val = 1.0f) + { + generate_random_data(host_ptr, size, min_val, max_val); + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_to_device() + { + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_from_device() + { + hipMemcpy(host_ptr, device_ptr, size * sizeof(T), hipMemcpyDeviceToHost); + } + + void zero() { hipMemset(device_ptr, 0, size * sizeof(T)); } +}; + +} // namespace validation +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt new file mode 100644 index 0000000000..e57678952e --- /dev/null +++ b/dispatcher/python/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# This directory contains Python utilities for the dispatcher examples. +# The main utility file is ctypes_utils.py which is used by GEMM Python examples. +# Conv Python examples use their own conv_utils.py in the examples directory. + +# No build targets needed - these are pure Python utilities. +message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md new file mode 100644 index 0000000000..9286acbf72 --- /dev/null +++ b/dispatcher/python/README.md @@ -0,0 +1,60 @@ +# CK Tile Dispatcher Python Utilities + +This directory contains Python utilities used by the dispatcher examples. + +## Contents + +- `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples + - `KernelConfig` - Kernel configuration dataclass + - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction + - `cleanup_gemm()` - Cleanup dispatcher resources + - `GemmRunner` - GPU execution helper + - Auto-correction and validation utilities + +- `conv_utils.py` - Core utilities for Conv Python examples + - `ConvSignature`, `ConvAlgorithm` - Convolution configuration + - `ConvProblem` - Problem definition + - `GpuConvRunner` - GPU execution helper + - `EnhancedConvCodegenRunner` - Kernel codegen utilities + +## Usage + +### GEMM Examples + +The GEMM Python examples in `dispatcher/examples/gemm/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + GemmRunner, +) +``` + +### Conv Examples + +The Conv Python examples in `dispatcher/examples/conv/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from conv_utils import ( + ConvSignature, + ConvAlgorithm, + ConvProblem, + GpuConvRunner, +) +``` + +## Requirements + +- Python 3.8+ +- NumPy +- HIP runtime (for GPU execution) diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py new file mode 100644 index 0000000000..821fc2b08d --- /dev/null +++ b/dispatcher/python/ctypes_utils.py @@ -0,0 +1,2347 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +CK Tile Dispatcher Utilities + +Common utilities for loading, compiling, and using the CK Tile dispatcher. + +Usage: + from ck_tile_dispatcher.utils import DispatcherLib, GemmRunner, Validator + + # Option 1: Auto-compile and load + lib = DispatcherLib.auto() + + # Option 2: Load existing library + lib = DispatcherLib.load("/path/to/libdispatcher_gemm.so") + + # Run GEMM + runner = GemmRunner(lib) + result = runner.run(A, B) + + # Validate + validator = Validator() + check = validator.check(result.C, C_reference) +""" + +import ctypes +import subprocess +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, List, Dict, Any +from dataclasses import dataclass, field +from concurrent.futures import ProcessPoolExecutor, as_completed +import multiprocessing +import time + + +# ============================================================================= +# Path Configuration +# ============================================================================= + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory""" + # This file is in dispatcher/python/ + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory""" + return get_dispatcher_root() / "build" + + +# ============================================================================= +# Supported Data Types +# ============================================================================= + +# All supported GEMM dtype combinations from warp_gemm_dispatcher.hpp +SUPPORTED_DTYPES = { + # dtype_a, dtype_b -> acc_dtype, warp_tiles + ("fp32", "fp32"): {"acc": "fp32", "warp_tiles": [(16, 16, 4), (16, 16, 16)]}, + ("fp16", "fp16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("bf16", "bf16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("fp8", "fp8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32), (16, 16, 64)], + }, + ("fp8", "bf8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 32)]}, + ("bf8", "fp8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 128)]}, + ("bf8", "bf8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32)], + }, + ("int8", "int8"): { + "acc": "int32", + "warp_tiles": [(32, 32, 16), (16, 16, 32), (16, 16, 16)], + }, + ("pk_fp4", "pk_fp4"): {"acc": "fp32", "warp_tiles": [(16, 16, 128)]}, +} + +# All valid individual dtypes +VALID_DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"] + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory""" + return get_build_dir() / "generated_kernels" + + +# ============================================================================= +# Arch Filter and Validation +# ============================================================================= + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + import sys + + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + +@dataclass +class ValidationResult: + """Result of kernel config validation.""" + + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + suggested_fixes: Dict[str, Any] = field(default_factory=dict) + + def print_result(self, indent: str = " "): + """Print validation result.""" + if self.is_valid: + print(f"{indent}✓ Configuration valid") + else: + print(f"{indent}⚠ Configuration has issues:") + for err in self.errors: + print(f"{indent} - {err}") + + if self.warnings: + for warn in self.warnings: + print(f"{indent} Warning: {warn}") + + if self.suggested_fixes: + print(f"{indent} Suggested fixes:") + for key, val in self.suggested_fixes.items(): + print(f"{indent} {key}: {val}") + + +def validate_kernel_config(config: "KernelConfig") -> ValidationResult: + """ + Validate a KernelConfig against arch filter rules. + + Validation considers the GEMM variant (standard, preshuffle, multi_d) + for operator-specific constraints like minimum tile sizes. + + Returns ValidationResult with is_valid, errors, and suggested fixes. + """ + arch_data = get_arch_filter_data() + + errors = [] + warnings = [] + suggested_fixes = {} + + pipeline = config.pipeline + epilogue = config.epilogue + scheduler = config.scheduler + dtype = config.dtype_a + arch = config.gfx_arch + variant = getattr(config, "variant", "standard") + + wave_m = config.wave_m + wave_n = config.wave_n + wave_k = config.wave_k + + warp_m = config.warp_m + warp_n = config.warp_n + warp_k = config.warp_k + + # Variant-specific tile constraints + if variant == "preshuffle": + # Preshuffle requires larger minimum tiles for efficiency + if config.tile_m < 64: + errors.append(f"Preshuffle requires tile_m >= 64, got {config.tile_m}") + suggested_fixes["tile_m"] = 64 + if config.tile_n < 64: + errors.append(f"Preshuffle requires tile_n >= 64, got {config.tile_n}") + suggested_fixes["tile_n"] = 64 + if config.tile_k < 32: + errors.append(f"Preshuffle requires tile_k >= 32, got {config.tile_k}") + suggested_fixes["tile_k"] = 32 + + elif variant == "multi_d": + # Multi-D has standard GEMM constraints + # Could add specific constraints here if needed + pass + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}" + ) + suggested_fixes["scheduler"] = "intrawave" + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}. Valid: {valid_str}" + ) + if warp_combos: + suggested_fixes["wave_m"] = warp_combos[0][0] + suggested_fixes["wave_n"] = warp_combos[0][1] + suggested_fixes["wave_k"] = warp_combos[0][2] + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}. Valid: {valid_str}" + ) + if warp_tile_combos: + suggested_fixes["warp_m"] = warp_tile_combos[0][0] + suggested_fixes["warp_n"] = warp_tile_combos[0][1] + suggested_fixes["warp_k"] = warp_tile_combos[0][2] + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return ValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + ) + + +def auto_correct_kernel_config( + config: "KernelConfig", verbose: bool = False +) -> Tuple["KernelConfig", bool, List[str]]: + """ + Validate and auto-correct a KernelConfig. + + Returns (corrected_config, was_modified, corrections_list). + If the config was valid, returns (original_config, False, []). + If corrections were made, returns (new_config, True, [list of correction descriptions]). + """ + validation = validate_kernel_config(config) + + if validation.is_valid: + return config, False, [] + + # Apply suggested fixes and track what changed + from dataclasses import replace + + fixes = validation.suggested_fixes + corrections = [] + + # Check each fix and describe what changed + if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: + corrections.append( + f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" + ) + + if "wave_m" in fixes or "wave_n" in fixes or "wave_k" in fixes: + old_wave = f"[{config.wave_m}, {config.wave_n}, {config.wave_k}]" + new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" + if old_wave != new_wave: + corrections.append( + f"Wave config: {old_wave} → {new_wave} " + f"(original not supported on {config.gfx_arch})" + ) + + if "warp_m" in fixes or "warp_n" in fixes or "warp_k" in fixes: + old_warp = f"[{config.warp_m}, {config.warp_n}, {config.warp_k}]" + new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" + if old_warp != new_warp: + corrections.append( + f"Warp tile: {old_warp} → {new_warp} " + f"(original not supported for {config.dtype_a} on {config.gfx_arch})" + ) + + new_config = replace( + config, + scheduler=fixes.get("scheduler", config.scheduler), + wave_m=fixes.get("wave_m", config.wave_m), + wave_n=fixes.get("wave_n", config.wave_n), + wave_k=fixes.get("wave_k", config.wave_k), + warp_m=fixes.get("warp_m", config.warp_m), + warp_n=fixes.get("warp_n", config.warp_n), + warp_k=fixes.get("warp_k", config.warp_k), + ) + + return new_config, True, corrections + + +def print_kernel_config(config: "KernelConfig", title: str = "KERNEL CONFIGURATION"): + """ + Print a formatted kernel configuration for GEMM. + + Args: + config: The KernelConfig to print + title: Title to display (e.g., "REQUESTED KERNEL CONFIGURATION") + """ + print() + print("=" * 70) + print(f" {title}") + print("=" * 70) + print(f" Data Type A: {config.dtype_a}") + print(f" Data Type B: {config.dtype_b}") + print(f" Data Type C: {config.dtype_c}") + print(f" Accumulator: {config.dtype_acc}") + print() + print( + f" Layout: {config.layout} (A={config.layout_a}, B={config.layout_b}, C={config.layout_c})" + ) + print() + print(f" Tile M x N x K: {config.tile_m} x {config.tile_n} x {config.tile_k}") + print(f" Wave Config: {config.wave_m} x {config.wave_n} x {config.wave_k}") + print(f" Warp Tile: {config.warp_m} x {config.warp_n} x {config.warp_k}") + print() + print(f" Pipeline: {config.pipeline}") + print(f" Scheduler: {config.scheduler}") + print(f" Epilogue: {config.epilogue}") + print() + print(f" Target Arch: {config.gfx_arch}") + print("=" * 70) + print() + + +def print_auto_correction( + original: "KernelConfig", + corrected: "KernelConfig", + corrections: List[str], + indent: str = " ", +): + """ + Print what was auto-corrected and why. + + Args: + original: Original configuration before correction + corrected: Configuration after correction + corrections: List of correction descriptions + indent: Indentation for output + """ + if not corrections: + print(f"{indent}✓ Configuration valid - no corrections needed") + return + + print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"{indent}" + "-" * 50) + for correction in corrections: + print(f"{indent} • {correction}") + print(f"{indent}" + "-" * 50) + print() + + +def find_matching_kernel_header(config: "KernelConfig") -> Optional[Path]: + """ + Find a kernel header that EXACTLY matches the config. + + Uses progressively relaxed matching strategies. + """ + kernel_dir = get_generated_kernels_dir() + + dtype = config.dtype_a + layout = config.layout + pipeline = config.pipeline + scheduler = config.scheduler + tile_str = config.tile_str + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Strategy 1: Exact match with ALL parameters including warp tile + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_{warp_str}.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 2: Match with tile and wave, any warp + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 3: Match with just tile (ignore wave/warp) + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 4: Match with intrawave (known to work) + pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 5: Any kernel with matching dtype/layout/tile + pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + return None + + +# ============================================================================= +# Library Loading +# ============================================================================= + + +class DispatcherLib: + """Wrapper for the dispatcher dynamic library""" + + # Default library search paths (relative to dispatcher root) + SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "build/libdispatcher_gemm_lib.so", + "build/examples/libdispatcher_gemm.so", + "build/lib/libdispatcher_gemm.so", + ] + + # Track loaded libraries globally for cleanup + _loaded_libs: List[Path] = [] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._closed = False + DispatcherLib._loaded_libs.append(path) + self._setup_functions() + + def _setup_functions(self): + """Setup ctypes function signatures""" + # Initialize + self._lib.dispatcher_initialize.argtypes = [] + self._lib.dispatcher_initialize.restype = ctypes.c_int + + # Alias for init + self._lib.dispatcher_init.argtypes = [] + self._lib.dispatcher_init.restype = ctypes.c_int + + # Get kernel count + self._lib.dispatcher_get_kernel_count.argtypes = [] + self._lib.dispatcher_get_kernel_count.restype = ctypes.c_int + + # Check if supported + self._lib.dispatcher_is_supported.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ] + self._lib.dispatcher_is_supported.restype = ctypes.c_int + + # Run GEMM + self._lib.dispatcher_run_gemm.argtypes = [ + ctypes.c_void_p, # A + ctypes.c_void_p, # B + ctypes.c_void_p, # C + ctypes.c_int64, # M + ctypes.c_int64, # N + ctypes.c_int64, # K + ctypes.POINTER(ctypes.c_float), # time_ms + ] + self._lib.dispatcher_run_gemm.restype = ctypes.c_int + + # Get kernel name + self._lib.dispatcher_get_kernel_name.argtypes = [] + self._lib.dispatcher_get_kernel_name.restype = ctypes.c_char_p + + # Select kernel + self._lib.dispatcher_select_kernel.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.dispatcher_select_kernel.restype = ctypes.c_int + + # Export JSON + self._lib.dispatcher_export_registry_json.argtypes = [] + self._lib.dispatcher_export_registry_json.restype = ctypes.c_char_p + + # Cleanup + self._lib.dispatcher_cleanup.argtypes = [] + self._lib.dispatcher_cleanup.restype = None + + @property + def path(self) -> Path: + return self._path + + def initialize(self) -> bool: + """Initialize the dispatcher""" + return self._lib.dispatcher_initialize() == 0 + + def get_kernel_count(self) -> int: + """Get number of registered kernels""" + return self._lib.dispatcher_get_kernel_count() + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if a problem size is supported""" + return self._lib.dispatcher_is_supported(M, N, K) == 1 + + def get_kernel_name(self) -> str: + """Get the kernel name""" + name = self._lib.dispatcher_get_kernel_name() + return name.decode("utf-8") if name else "unknown" + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select kernel for problem and return its name""" + buffer = ctypes.create_string_buffer(256) + result = self._lib.dispatcher_select_kernel(M, N, K, buffer, 256) + if result == 0: + return buffer.value.decode("utf-8") + return None + + def run_gemm( + self, A: np.ndarray, B: np.ndarray, C: np.ndarray, M: int, N: int, K: int + ) -> Tuple[int, float]: + """ + Run GEMM operation + + Returns: (status, time_ms) + status: 0 = success, -1 = error, -2 = no suitable kernel + """ + time_ms = ctypes.c_float(0.0) + + status = self._lib.dispatcher_run_gemm( + A.ctypes.data_as(ctypes.c_void_p), + B.ctypes.data_as(ctypes.c_void_p), + C.ctypes.data_as(ctypes.c_void_p), + M, + N, + K, + ctypes.byref(time_ms), + ) + + return status, time_ms.value + + def export_json(self) -> Optional[str]: + """Export registry to JSON string""" + json_ptr = self._lib.dispatcher_export_registry_json() + if json_ptr: + return json_ptr.decode("utf-8") + return None + + def export_registry_json(self) -> str: + """Alias for export_json for compatibility""" + return self.export_json() or "{}" + + def cleanup(self): + """Cleanup dispatcher resources""" + self._lib.dispatcher_cleanup() + + @classmethod + def find(cls) -> Optional[Path]: + """Find the dispatcher library""" + root = get_dispatcher_root() + + for rel_path in cls.SEARCH_PATHS: + path = root / rel_path + if path.exists(): + return path + + return None + + @classmethod + def load(cls, path: Optional[Path] = None) -> Optional["DispatcherLib"]: + """Load the dispatcher library from path or auto-find""" + if path is None: + path = cls.find() + + if path is None or not path.exists(): + return None + + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError as e: + print(f"Failed to load library: {e}") + return None + + @classmethod + def compile(cls, output_path: Optional[Path] = None) -> Optional[Path]: + """Compile the dispatcher library""" + root = get_dispatcher_root() + ck_root = get_ck_root() + + if output_path is None: + output_path = get_build_dir() / "examples" / "libdispatcher_gemm.so" + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Find a kernel header to include + kernel_dir = get_generated_kernels_dir() + kernel_headers = list(kernel_dir.glob("gemm_fp16_rcr_compv4*128x128x32*.hpp")) + + if not kernel_headers: + print("No kernel headers found. Generate kernels first.") + return None + + kernel_header = kernel_headers[0] + + # Use the ctypes binding source file + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f"Source file not found: {ctypes_source}") + print( + "Please build with CMake: cd build && cmake .. && make dispatcher_gemm_lib" + ) + return None + + # CK_TILE_SINGLE_KERNEL_INCLUDE exports types to global namespace for ctypes binding + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + "--offload-arch=gfx942", + "-DAMDGPU_ARCH=gfx942", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(output_path), + ] + + try: + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=120 + ) + if result.returncode == 0: + return output_path + else: + print(f"Compilation failed:\n{result.stderr}") + return None + except subprocess.TimeoutExpired: + print("Compilation timed out") + return None + + @classmethod + def auto(cls, recompile: bool = False) -> Optional["DispatcherLib"]: + """Auto-find or compile the library. + + Note: The library is built by CMake with a specific kernel configuration. + If you need a different dtype/layout, rebuild with: + cd build && cmake .. && make dispatcher_gemm_lib + """ + lib = cls.load() + if lib is not None: + if lib.initialize(): + return lib + else: + print(" Library found but failed to initialize") + print( + " Rebuild with: cd build && cmake .. && make dispatcher_gemm_lib" + ) + + # Don't fall back to old compile method - use CMake instead + print(" Library not found. Build with:") + print(" cd dispatcher/build && cmake .. && make dispatcher_gemm_lib") + return None + + +# ============================================================================= +# GEMM Runner +# ============================================================================= + + +@dataclass +class GemmResult: + """Result of a GEMM operation""" + + output: np.ndarray # The output C matrix + time_ms: float + status: int + tflops: float + kernel_name: str + + @property + def success(self) -> bool: + return self.status == 0 + + # Alias for backward compatibility + @property + def C(self) -> np.ndarray: + return self.output + + +class GemmRunner: + """High-level GEMM runner using the dispatcher""" + + def __init__(self, lib: DispatcherLib): + self.lib = lib + + def run(self, A: np.ndarray, B: np.ndarray, dtype=np.float16) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + dtype: Output data type (default: float16) + + Returns: + GemmResult with output matrix and timing + """ + M, K = A.shape + K2, N = B.shape + + assert K == K2, f"Dimension mismatch: A is {M}x{K}, B is {K2}x{N}" + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run + status, time_ms = self.lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self.lib.get_kernel_name(), + ) + + def benchmark( + self, M: int, N: int, K: int, warmup: int = 2, iterations: int = 10 + ) -> dict: + """Benchmark GEMM for given dimensions""" + A = np.random.randn(M, K).astype(np.float16) + B = np.random.randn(K, N).astype(np.float16) + + times = [] + + # Warmup + for _ in range(warmup): + self.run(A, B) + + # Benchmark + for _ in range(iterations): + result = self.run(A, B) + if result.success: + times.append(result.time_ms) + + if not times: + return {"error": "All iterations failed"} + + flops = 2.0 * M * N * K + avg_time = sum(times) / len(times) + + return { + "M": M, + "N": N, + "K": K, + "min_ms": min(times), + "avg_ms": avg_time, + "max_ms": max(times), + "tflops": (flops / (avg_time * 1e-3)) / 1e12, + "iterations": len(times), + } + + +# ============================================================================= +# Validation Utilities +# ============================================================================= + + +class Validator: + """Utilities for validating GEMM results""" + + def __init__(self, rtol: float = 1e-3, atol: float = 1e-2): + self.rtol = rtol + self.atol = atol + + def check( + self, result: np.ndarray, reference: np.ndarray + ) -> Tuple[bool, float, float]: + """ + Check if result matches reference + + Returns: (is_correct, max_diff, mean_diff) + """ + result = result.astype(np.float32) + reference = reference.astype(np.float32) + + diff = np.abs(result - reference) + max_diff = float(np.max(diff)) + mean_diff = float(np.mean(diff)) + + close = np.allclose(result, reference, rtol=self.rtol, atol=self.atol) + + return close, max_diff, mean_diff + + def compute_reference(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute reference GEMM result using NumPy""" + return np.matmul(A.astype(np.float32), B.astype(np.float32)) + + +# ============================================================================= +# Code Generation Utilities +# ============================================================================= + + +def get_codegen_path() -> Path: + """Get path to unified_gemm_codegen.py""" + return get_dispatcher_root() / "codegen" / "unified_gemm_codegen.py" + + +@dataclass +class CodegenResult: + """Result of kernel code generation""" + + success: bool + output_dir: Path + variant: str + stdout: str = "" + stderr: str = "" + kernel_count: int = 0 + elapsed_seconds: float = 0.0 + instance_names: List[str] = field(default_factory=list) + + def get_generated_kernels(self) -> List[Path]: + """Get list of generated kernel headers""" + if self.output_dir.exists(): + return list(self.output_dir.glob("*.hpp")) + return [] + + def print_instances(self, prefix: str = " "): + """Print all generated instance names.""" + for name in self.instance_names: + print(f"{prefix}{name}") + + +def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: + """ + Worker function for parallel codegen execution. + + This is a module-level function to allow pickling for ProcessPoolExecutor. + """ + import sys + import subprocess + from pathlib import Path + + codegen_path = Path(args["codegen_path"]) + out_dir = Path(args["output_dir"]) + variant = args["variant"] + datatype = args["datatype"] + layout = args["layout"] + gpu_target = args["gpu_target"] + extra_args = args.get("extra_args", []) + timeout = args.get("timeout", 300) + + out_dir.mkdir(parents=True, exist_ok=True) + + start = time.time() + + # Get existing kernels before generation + existing_kernels = set(out_dir.glob("*.hpp")) if out_dir.exists() else set() + + cmd = [ + sys.executable, + str(codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + datatype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + + if extra_args: + cmd.extend(extra_args) + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + + # Get new kernels after generation + all_kernels = set(out_dir.glob("*.hpp")) + new_kernels = all_kernels - existing_kernels + kernel_count = len(all_kernels) + elapsed = time.time() - start + + # Build instance names list for verbose output + instance_names = sorted([k.stem for k in new_kernels]) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=variant, + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + instance_names=instance_names, + ) + except subprocess.TimeoutExpired: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=f"Code generation timed out ({timeout}s)", + elapsed_seconds=time.time() - start, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=str(e), + elapsed_seconds=time.time() - start, + ) + + +# ============================================================================= +# Preshuffle Utilities +# ============================================================================= + + +def preshuffle_weight_matrix( + B: np.ndarray, + warp_tile_n: int, + warp_tile_k: int, + arch: str = "gfx942", +) -> np.ndarray: + """ + Preshuffle the B (weight) matrix for optimized GEMM inference. + + This transforms the B matrix layout to match the expected memory access + pattern for preshuffle-enabled kernels. The transformation reorders data + so that warp-level loads are coalesced. + + Args: + B: Weight matrix of shape (K, N) in column-major / (K, N) layout + warp_tile_n: Warp tile size in N dimension (e.g., 32) + warp_tile_k: Warp tile size in K dimension (e.g., 16) + arch: Target GPU architecture (gfx9xx, gfx11xx, gfx12xx) + + Returns: + Shuffled B matrix with same data but reordered layout + + Example: + >>> B = np.random.randn(1024, 2048).astype(np.float16) + >>> B_shuffled = preshuffle_weight_matrix(B, warp_tile_n=32, warp_tile_k=16) + >>> # Use B_shuffled with preshuffle-enabled kernel + """ + K, N = B.shape + + # Validate dimensions are divisible by warp tiles + if N % warp_tile_n != 0: + raise ValueError(f"N ({N}) must be divisible by warp_tile_n ({warp_tile_n})") + if K % warp_tile_k != 0: + raise ValueError(f"K ({K}) must be divisible by warp_tile_k ({warp_tile_k})") + + # Architecture-specific shuffle patterns + # Based on ck_tile/host/tensor_shuffle_utils.hpp + if arch.startswith("gfx12"): + # GFX12 (RDNA4) pattern + divisor = 2 + k_abk1_per_lane = 8 + k_abk0_per_lane = warp_tile_k // divisor // k_abk1_per_lane + + if k_abk0_per_lane <= 0: + raise ValueError( + f"warp_tile_k ({warp_tile_k}) too small for GFX12 preshuffle" + ) + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, k0, div, k1) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + k_abk0_per_lane, + divisor, + k_abk1_per_lane, + ) + # Permute: {0, 2, 4, 1, 3, 5} + B_shuffled = np.transpose(B_view, (0, 2, 4, 1, 3, 5)) + + elif arch.startswith("gfx11"): + # GFX11 (RDNA3) pattern - divisor = 1 + divisor = 1 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + else: + # GFX9 (CDNA) pattern - wave64 + divisor = 2 if warp_tile_n == 32 else 4 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + # Return contiguous array with same dtype + return np.ascontiguousarray(B_shuffled.reshape(-1)).reshape(B.shape) + + +def is_preshuffle_supported(arch: str) -> bool: + """Check if preshuffle is supported for the given architecture.""" + # Preshuffle is supported on CDNA (gfx9xx) and RDNA (gfx11xx, gfx12xx) + return arch.startswith(("gfx9", "gfx11", "gfx12")) + + +@dataclass +class KernelConfig: + """ + Complete kernel configuration for GEMM. + + This defines all parameters needed to generate and run a specific kernel. + """ + + # Data types + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + + # Layouts (row/col) + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + # Tile shape (work per thread block) + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + + # Wave shape (warps per block) + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # Warp tile (elements per warp) + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + + # Block configuration + block_size: int = 256 + + # Pipeline configuration + pipeline: str = "compv4" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + # GPU target + gfx_arch: str = "gfx942" + + # GEMM variant (affects arch filter validation) + # "standard", "preshuffle", or "multi_d" + variant: str = "standard" + + @property + def layout(self) -> str: + """Get layout string (e.g., 'rcr' for row-col-row)""" + mapping = {"row": "r", "col": "c"} + return mapping[self.layout_a] + mapping[self.layout_b] + mapping[self.layout_c] + + @property + def tile_str(self) -> str: + """Get tile size string""" + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + def print_config(self, indent: str = " "): + """Pretty print the configuration.""" + print(f"{indent}KernelConfig:") + print( + f"{indent} Data types: A={self.dtype_a}, B={self.dtype_b}, C={self.dtype_c}, Acc={self.dtype_acc}" + ) + print( + f"{indent} Layouts: A={self.layout_a}, B={self.layout_b}, C={self.layout_c} ({self.layout})" + ) + print(f"{indent} Tile: {self.tile_m}x{self.tile_n}x{self.tile_k}") + print(f"{indent} Waves: {self.wave_m}x{self.wave_n}x{self.wave_k}") + print(f"{indent} Warp tile: {self.warp_m}x{self.warp_n}x{self.warp_k}") + print(f"{indent} Block size: {self.block_size}") + print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}") + print(f"{indent} Padding: M={self.pad_m}, N={self.pad_n}, K={self.pad_k}") + print(f"{indent} Target: {self.gfx_arch}") + + +class CodegenRunner: + """ + Runner for the unified GEMM code generator with parallel execution support. + + Usage: + codegen = CodegenRunner() + + # Generate standard kernels + result = codegen.generate("standard") + + # Generate preshuffle kernels + result = codegen.generate("preshuffle") + + # Generate multi-D kernels + result = codegen.generate("multi_d") + + # Generate all variants IN PARALLEL + results = codegen.generate_all_parallel() + + # Generate multiple configs IN PARALLEL + configs = [KernelConfig(...), KernelConfig(...)] + results = codegen.generate_configs_parallel(configs) + + # Generate with custom output directory + result = codegen.generate("standard", output_dir=Path("/custom/path")) + + # Generate from specific config + config = KernelConfig(tile_m=256, tile_n=256, tile_k=64) + result = codegen.generate_from_config(config) + """ + + VARIANTS = ["standard", "preshuffle", "multi_d"] + + def __init__( + self, + codegen_path: Optional[Path] = None, + output_dir: Optional[Path] = None, + datatype: str = "fp16", + layout: str = "rcr", + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, + ): + self.codegen_path = codegen_path or get_codegen_path() + self.output_dir = output_dir or get_generated_kernels_dir() + self.datatype = datatype + self.layout = layout + self.gpu_target = gpu_target + # Default to CPU count, but cap at reasonable value + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + def _make_args( + self, + variant: str, + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + timeout: int = 300, + show_instances: bool = False, + ) -> Dict[str, Any]: + """Build args dict for parallel worker.""" + return { + "codegen_path": str(self.codegen_path), + "output_dir": str(output_dir or self.output_dir), + "variant": variant, + "datatype": self.datatype, + "layout": self.layout, + "gpu_target": self.gpu_target, + "extra_args": extra_args or [], + "timeout": timeout, + "show_instances": show_instances, + } + + def generate( + self, + variant: str = "standard", + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernels for a specific variant (single-threaded). + + Args: + variant: One of "standard", "preshuffle", "multi_d" + output_dir: Override output directory + extra_args: Additional arguments to pass to codegen + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + CodegenResult with generation status and info + """ + args = self._make_args( + variant, output_dir, extra_args, show_instances=show_instances + ) + result = _run_codegen_subprocess(args) + + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + + return result + + def generate_all(self, output_dir: Optional[Path] = None) -> List[CodegenResult]: + """Generate all variants sequentially (use generate_all_parallel for speed).""" + results = [] + for variant in self.VARIANTS: + result = self.generate(variant, output_dir) + results.append(result) + return results + + def generate_all_parallel( + self, + output_dir: Optional[Path] = None, + variants: Optional[List[str]] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate all variants IN PARALLEL. + + Args: + output_dir: Override output directory + variants: List of variants to generate (default: all) + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each variant + """ + variants = variants or self.VARIANTS + start_total = time.time() + + if verbose: + print( + f"Generating {len(variants)} variants in parallel (workers={self.max_workers})..." + ) + + # Build args for each variant + args_list = [self._make_args(v, output_dir) for v in variants] + for args in args_list: + args["show_instances"] = show_instances + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=output_dir or self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_configs_parallel( + self, + configs: List["KernelConfig"], + output_dir: Optional[Path] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate kernels from multiple configs IN PARALLEL. + + Each config generates independently, allowing maximum parallelism. + + Args: + configs: List of KernelConfig objects + output_dir: Override output directory + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each config + """ + start_total = time.time() + out_dir = output_dir or self.output_dir + + if verbose: + print( + f"Generating {len(configs)} configs in parallel (workers={self.max_workers})..." + ) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = {} + for config in configs: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(out_dir), + "variant": "standard", + "datatype": config.dtype_a, + "layout": config.layout, + "gpu_target": config.gfx_arch, + "extra_args": [], + "timeout": 300, + "show_instances": show_instances, + } + future = executor.submit(_run_codegen_subprocess, args) + futures[future] = config.tile_str + + for future in as_completed(futures): + tile_str = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {tile_str}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_batch_parallel( + self, + batch: List[Dict[str, Any]], + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate a batch of kernel specs IN PARALLEL. + + This is the most flexible parallel generation method. + + Args: + batch: List of dicts with keys: variant, datatype, layout, gpu_target, output_dir + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult + """ + start_total = time.time() + + if verbose: + print( + f"Generating {len(batch)} kernel specs in parallel (workers={self.max_workers})..." + ) + + # Build args for each spec + args_list = [] + for spec in batch: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(spec.get("output_dir", self.output_dir)), + "variant": spec.get("variant", "standard"), + "datatype": spec.get("datatype", self.datatype), + "layout": spec.get("layout", self.layout), + "gpu_target": spec.get("gpu_target", self.gpu_target), + "extra_args": spec.get("extra_args", []), + "timeout": spec.get("timeout", 300), + "show_instances": show_instances, + } + args_list.append(args) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_from_config( + self, + config: KernelConfig, + output_dir: Optional[Path] = None, + force: bool = False, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernel from a specific KernelConfig. + + This generates ONLY the specific kernel header needed (not all kernels). + Note: This does NOT rebuild the library - use build_library_for_configs() + for that. + + Args: + config: KernelConfig with all kernel parameters + output_dir: Override output directory + force: Force regeneration even if kernel exists + show_instances: Print instance names when generating + + Returns: + CodegenResult with the specific kernel + """ + import sys + import json + import tempfile + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + # Build kernel filename pattern for this config + # Note: padding flags may differ from config (arch filter may enable padding) + tile_str = config.tile_str # e.g., "128x128x32" + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Build pattern - use * for padding flags since arch filter may change them + precise_pattern = f"gemm_{config.dtype_a}_{config.layout}_{config.pipeline}_{config.epilogue}_{config.scheduler}_*_*_*_*_{tile_str}_{wave_str}_{warp_str}.hpp" + + # Check if exact kernel already exists + existing = list(out_dir.glob(precise_pattern)) + if existing and not force: + instance_names = sorted([k.stem for k in existing]) + if show_instances: + for name in instance_names: + print(f" Kernel exists: {name}") + + return CodegenResult( + success=True, + output_dir=out_dir, + variant=f"config:{tile_str}", + kernel_count=len(existing), + instance_names=instance_names, + stdout=f"Kernel exists, using: {existing[0].name}", + ) + + if not self.codegen_path.exists(): + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=f"Codegen not found at {self.codegen_path}", + ) + + start = time.time() + + # Create a temporary config file for single-kernel generation + # Format must match what unified_gemm_codegen.py expects + single_config = { + "tile_config": { + "tile_m": [config.tile_m], + "tile_n": [config.tile_n], + "tile_k": [config.tile_k], + "warp_m": [config.wave_m], + "warp_n": [config.wave_n], + "warp_k": [config.wave_k], + "warp_tile_m": [config.warp_m], + "warp_tile_n": [config.warp_n], + "warp_tile_k": [config.warp_k], + }, + "trait_config": { + "pipeline": [config.pipeline], + "epilogue": [config.epilogue], + "scheduler": [config.scheduler], + "pad_m": [config.pad_m], + "pad_n": [config.pad_n], + "pad_k": [config.pad_k], + "persistent": [False], + }, + } + + # Write temp config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(single_config, f) + config_file = f.name + + try: + # Generate ONLY this specific kernel using config file + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + config.dtype_a, + "--layout", + config.layout, + "--gpu-target", + config.gfx_arch, + "--config", + config_file, + "--variants", + "standard", + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + # Find the generated kernel + matching = list(out_dir.glob(precise_pattern)) + kernel_count = len(matching) + elapsed = time.time() - start + + instance_names = sorted([k.stem for k in matching]) + if show_instances and instance_names: + for name in instance_names: + print(f" Generated: {name}") + + return CodegenResult( + success=result.returncode == 0 and kernel_count > 0, + output_dir=out_dir, + variant=f"config:{tile_str}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + instance_names=instance_names, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + finally: + # Clean up temp file + import os + + try: + os.unlink(config_file) + except Exception: + pass + + def _rebuild_library_for_config( + self, config: KernelConfig, kernel_header: Path + ) -> Optional[Path]: + """ + Rebuild the library with the specified kernel header using hipcc directly. + + This compiles a new library with exactly the kernel specified. + Builds to a UNIQUE filename to avoid conflicts with loaded libraries. + + Architecture Note - C++ vs Python Paths: + ----------------------------------------- + C++ Multi-Kernel Path: + - Each kernel is in its own namespace (ns_gemm_...) + - Multiple kernel headers can be included together + - Uses namespace-qualified types: ns_...:SelectedKernel + - Does NOT define CK_TILE_SINGLE_KERNEL_INCLUDE + - Registration code uses block-scoped type aliases + + Python Single-Kernel JIT Path (this function): + - Each library contains exactly ONE kernel + - Uses -DCK_TILE_SINGLE_KERNEL_INCLUDE to export types to global namespace + - gemm_ctypes_lib.cpp expects: SelectedKernel, KERNEL_NAME, ADataType, etc. + - Different configs get different library files (by dtype/layout) + - This enables Python to use any kernel config without pre-building all + + Returns: Path to new library, or None on failure + """ + build_dir = get_build_dir() + # Use unique filename based on dtype/layout to avoid overwriting loaded library + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_lib.so" + lib_path = build_dir / "examples" / lib_name + + print(f" Rebuilding library: {lib_name}") + print(f" With kernel: {kernel_header.name}") + + root = get_dispatcher_root() + ck_root = root.parent + + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f" Source not found: {ctypes_source}") + return None + + # Link against the static dispatcher library (contains Registry, Dispatcher) + static_lib = build_dir / "libck_tile_dispatcher.a" + if not static_lib.exists(): + print(f" Static library not found: {static_lib}") + print(" Build with: cd build && cmake .. && make ck_tile_dispatcher") + return None + + # Compile source to object first, then link + obj_file = lib_path.with_suffix(".o") + + # Step 1: Compile source to object + # CK_TILE_SINGLE_KERNEL_INCLUDE enables global namespace exports in the kernel header + # This exports: SelectedKernel, KERNEL_NAME, ADataType, BDataType, CDataType, AccDataType + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", # Compile only + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', # Pass arch as string for gemm_ctypes_lib.cpp + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + try: + print(" Compiling source...") + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode != 0: + print(f" Compilation failed: {result.stderr[:300]}") + return None + + # Step 2: Link object with static library into shared library + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + print(" Linking...") + result = subprocess.run( + link_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode == 0: + print(f" ✓ Library rebuilt: {lib_path.name}") + # Clean up object file + obj_file.unlink(missing_ok=True) + return lib_path + else: + print(f" Linking failed: {result.stderr[:300]}") + return None + except subprocess.TimeoutExpired: + print(" Build timed out") + return None + except Exception as e: + print(f" Build error: {e}") + return None + + def generate_preselected( + self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None + ) -> CodegenResult: + """ + Generate kernels from a preselected set. + + Args: + preset: Preselected kernel set name (e.g., "fp16_rcr_essential") + output_dir: Override output directory + + Returns: + CodegenResult + """ + import sys + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--preselected", + preset, + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + kernel_count = len(list(out_dir.glob("*.hpp"))) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=f"preselected:{preset}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"preselected:{preset}", + stderr=str(e), + ) + + def ensure_kernels_exist(self) -> bool: + """ + Ensure kernel headers exist, generating if necessary. + + Returns: + True if kernels exist or were successfully generated + """ + if self.output_dir.exists(): + kernels = list(self.output_dir.glob("*.hpp")) + if kernels: + return True + + # Generate standard kernels + result = self.generate("standard") + return result.success + + def list_kernels(self) -> List[Path]: + """List all generated kernel headers""" + if self.output_dir.exists(): + return sorted(self.output_dir.glob("*.hpp")) + return [] + + def categorize_kernels(self) -> dict: + """ + Categorize kernels by tile size and variant. + + Returns: + Dict with categories by tile size and variant type + """ + kernels = self.list_kernels() + + # Separate by variant first + preshuffle = [k for k in kernels if "_preshuffle" in k.name] + multi_d = [k for k in kernels if "_multid_" in k.name] + standard = [ + k + for k in kernels + if "_preshuffle" not in k.name and "_multid_" not in k.name + ] + + # Categorize standard kernels by tile size + compute = [k for k in standard if "_256x" in k.name] + memory = [k for k in standard if "_128x" in k.name] + latency = [k for k in standard if "_64x" in k.name or "_32x" in k.name] + + return { + "total": len(kernels), + "standard": len(standard), + "compute": compute, + "memory": memory, + "latency": latency, + "preshuffle": preshuffle, + "multi_d": multi_d, + } + + +# ============================================================================= +# Registry and Dispatcher (Explicit API) +# ============================================================================= + + +class Registry: + """ + Kernel registry - stores and manages kernel instances. + + This provides an explicit registry API that mirrors the C++ Registry class. + + Usage: + registry = Registry() + registry.register_kernel(kernel_config) + dispatcher = Dispatcher(registry) + """ + + def __init__(self, lib: Optional[DispatcherLib] = None, name: str = "default"): + self._lib = lib + self._name = name + self._kernels: List[KernelConfig] = [] + + @property + def name(self) -> str: + return self._name + + @property + def kernel_count(self) -> int: + if self._lib: + return self._lib.get_kernel_count() + return len(self._kernels) + + def register_kernel(self, config: KernelConfig) -> bool: + """Register a kernel configuration.""" + self._kernels.append(config) + return True + + def get_kernels(self) -> List[KernelConfig]: + """Get all registered kernel configs.""" + return self._kernels.copy() + + def clear(self): + """Clear all kernels.""" + self._kernels.clear() + + def bind_library(self, lib: DispatcherLib): + """Bind to a loaded dispatcher library.""" + self._lib = lib + + def __repr__(self) -> str: + return f"Registry(name='{self._name}', kernels={self.kernel_count})" + + +class Dispatcher: + """ + Kernel dispatcher - selects and runs kernels for problems. + + This provides an explicit dispatcher API that mirrors the C++ Dispatcher class. + + Usage: + registry = Registry() + registry.register_kernel(config) + + dispatcher = Dispatcher(registry) + result = dispatcher.run(A, B, M, N, K) + """ + + def __init__(self, registry: Registry, lib: Optional[DispatcherLib] = None): + self._registry = registry + self._lib = lib or registry._lib + + @property + def registry(self) -> Registry: + return self._registry + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select best kernel for problem dimensions.""" + if self._lib: + return self._lib.select_kernel(M, N, K) + # Fallback: return first matching kernel + for config in self._registry.get_kernels(): + return f"kernel_{config.tile_str}" + return None + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if problem size is supported.""" + if self._lib: + return self._lib.is_supported(M, N, K) + return len(self._registry.get_kernels()) > 0 + + def run(self, A: np.ndarray, B: np.ndarray, M: int, N: int, K: int) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + M, N, K: Problem dimensions + + Returns: + GemmResult with output and timing + """ + if self._lib is None: + raise RuntimeError("Dispatcher not bound to library") + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run via library + status, time_ms = self._lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self._lib.get_kernel_name() if self._lib else "unknown", + ) + + def __repr__(self) -> str: + return f"Dispatcher(registry={self._registry.name}, kernels={self._registry.kernel_count})" + + +# ============================================================================= +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + print("\n" + "=" * 60) + print("All tests passed!") + + +# ============================================================================= +# High-Level Helper Functions +# ============================================================================= + + +@dataclass +class GemmSetupResult: + """Result of setup_gemm_dispatcher""" + + success: bool + dispatcher: Optional[Dispatcher] = None + lib: Optional[DispatcherLib] = None + registry: Optional[Registry] = None + codegen: Optional[CodegenRunner] = None + config: Optional[KernelConfig] = None + kernel_header: Optional[Path] = None + error: str = "" + corrections: List[str] = field(default_factory=list) + + +def setup_gemm_dispatcher( + config: KernelConfig, + registry_name: str = "gemm_registry", + verbose: bool = True, + auto_rebuild: bool = True, +) -> GemmSetupResult: + """ + High-level helper to setup a GEMM dispatcher from a kernel config. + + This handles: + 1. Validate config against arch filter (auto-correct if needed) + 2. Generate kernel code if needed + 3. Find matching kernel header + 4. Load or rebuild library (if dtype mismatch) + 5. Create registry and dispatcher + + Args: + config: KernelConfig with all parameters + registry_name: Name for the registry + verbose: Print progress messages + auto_rebuild: Rebuild library if dtype doesn't match + + Returns: + GemmSetupResult with dispatcher, lib, registry, etc. + """ + result = GemmSetupResult(success=False, config=config) + + def log(msg): + if verbose: + print(msg) + + # Step 1: Validate config + log(" Validating config...") + validation = validate_kernel_config(config) + if not validation.is_valid: + log(" ⚠ Auto-correcting configuration...") + config, was_modified, corrections = auto_correct_kernel_config( + config, verbose=verbose + ) + result.config = config + result.corrections = corrections + # Note: corrections will be displayed by the caller via print_auto_correction + + # Step 2: Setup codegen and generate kernel + log(f" Generating kernel (tile={config.tile_str})...") + codegen = CodegenRunner( + datatype=config.dtype_a, + layout=config.layout, + gpu_target=config.gfx_arch, + ) + result.codegen = codegen + + codegen_result = codegen.generate_from_config(config) + if not codegen_result.success: + log(" ⚠ Kernel generation: using existing") + + # Step 3: Find matching kernel header + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + if not kernel_header: + log(" ⚠ No matching kernel header found") + + # Step 4: Load library + log(" Loading library...") + lib = DispatcherLib.auto() + if lib is None: + result.error = "Could not load dispatcher library" + return result + result.lib = lib + + # Check if library kernel matches config - rebuild if ANY parameter differs + lib_kernel = lib.get_kernel_name() + needs_rebuild = False + mismatches = [] + + if lib_kernel: + # Build expected kernel signature components from config + expected_parts = { + "dtype": config.dtype_a, + "layout": config.layout, + "pipeline": config.pipeline, + "epilogue": config.epilogue, + "scheduler": config.scheduler, + "tile": f"{config.tile_m}x{config.tile_n}x{config.tile_k}", + "wave": f"{config.wave_m}x{config.wave_n}x{config.wave_k}", + "warp": f"{config.warp_m}x{config.warp_n}x{config.warp_k}", + } + + # Check each component against the library kernel name + for name, expected in expected_parts.items(): + if expected not in lib_kernel: + needs_rebuild = True + mismatches.append(f"{name}={expected}") + + if needs_rebuild and auto_rebuild: + log(f" Library kernel doesn't match config: {', '.join(mismatches)}") + log(" Rebuilding library for exact config match...") + + # First ensure we have a kernel header for this exact config + if not kernel_header: + # Generate kernel for the exact config + log(" Generating kernel for config...") + codegen_result = codegen.generate_from_config(config, force=True) + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + + if kernel_header: + new_lib_path = codegen._rebuild_library_for_config(config, kernel_header) + if new_lib_path: + lib = DispatcherLib.load(new_lib_path) + if lib is None or not lib.initialize(): + result.error = "Failed to load rebuilt library" + return result + result.lib = lib + log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + else: + log(" ⚠ Rebuild failed, using existing library") + else: + log(" ⚠ No kernel header found for config, using existing library") + + # Step 5: Create registry and dispatcher + log(" Creating registry and dispatcher...") + registry = Registry(name=registry_name, lib=lib) + registry.register_kernel(config) + result.registry = registry + + dispatcher = Dispatcher(registry=registry, lib=lib) + result.dispatcher = dispatcher + + log(f" ✓ Ready: {lib.get_kernel_name()}") + + result.success = True + return result + + +def cleanup_gemm(): + """ + Cleanup function to call after running GEMM examples. + + This helps ensure clean state between examples by: + 1. Clearing any global state + 2. Suggesting garbage collection + """ + import gc + + # Clear loaded libraries list + DispatcherLib._loaded_libs.clear() + + # Suggest garbage collection + gc.collect() + + +def cleanup_generated_kernels( + keep_default: bool = True, + verbose: bool = False, +) -> int: + """ + Clean up generated kernel files. + + Call this at the start of examples to ensure fresh state. + + Args: + keep_default: Keep the default fp16 kernel (True) or delete all (False) + verbose: Print what's being deleted + + Returns: + Number of files deleted + """ + + kernel_dir = get_generated_kernels_dir() + if not kernel_dir.exists(): + return 0 + + deleted = 0 + + # Default kernel pattern to keep + default_pattern = ( + "gemm_fp16_rcr_compv4_cshuffle_intrawave_*_128x128x32_2x2x1_16x16x16.hpp" + ) + + for f in kernel_dir.glob("*.hpp"): + # Skip dispatcher_wrappers directory + if f.is_dir(): + continue + + # Optionally keep default kernel + if keep_default and f.match(default_pattern): + continue + + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + # Also clean up any temp libs + build_dir = get_build_dir() + examples_dir = build_dir / "examples" + if examples_dir.exists(): + for f in examples_dir.glob("libdispatcher_gemm_*_lib.so"): + if f.name != "libdispatcher_gemm_lib.so": + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + return deleted + + +def reset_for_example(verbose: bool = False): + """ + Reset state for a fresh example run. + + Call this at the START of each example to ensure clean state. + Cleans up generated kernels (except default) and resets globals. + """ + # Cleanup any previously generated kernels + deleted = cleanup_generated_kernels(keep_default=True, verbose=verbose) + if verbose and deleted > 0: + print(f" Cleaned up {deleted} generated files") + + # Clear any cached state + cleanup_gemm() + + +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + # Test high-level helper + print("\n4. Testing setup_gemm_dispatcher...") + config = KernelConfig(tile_m=128, tile_n=128, tile_k=32) + setup = setup_gemm_dispatcher(config, verbose=True) + print(f" Success: {setup.success}") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("All tests passed!") diff --git a/dispatcher/python/pytest.ini b/dispatcher/python/pytest.ini new file mode 100644 index 0000000000..08cd235fda --- /dev/null +++ b/dispatcher/python/pytest.ini @@ -0,0 +1,43 @@ +[pytest] +# Pytest configuration for CK Tile Dispatcher Python tests + +# Test discovery +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Test paths +testpaths = tests + +# Options +addopts = + -v + --strict-markers + --tb=short + --color=yes + --durations=10 + +# Markers +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + cuda: marks tests requiring CUDA/ROCm + torch: marks tests requiring PyTorch + integration: marks integration tests + unit: marks unit tests + +# Coverage +[coverage:run] +source = . +omit = + */tests/* + */examples/* + setup.py + +[coverage:report] +precision = 2 +show_missing = True +skip_covered = False + +[coverage:html] +directory = htmlcov + diff --git a/dispatcher/python/requirements.txt b/dispatcher/python/requirements.txt new file mode 100644 index 0000000000..9d429235f7 --- /dev/null +++ b/dispatcher/python/requirements.txt @@ -0,0 +1,22 @@ +# Core dependencies +numpy>=1.19.0 + +# Optional dependencies (install with pip install -e ".[torch]") +# torch>=2.0.0 + +# Development dependencies (install with pip install -e ".[dev]") +# pytest>=6.0.0 +# pytest-cov>=2.0.0 +# black>=21.0 +# flake8>=3.9.0 +# mypy>=0.910 +# isort>=5.0.0 + +# Visualization dependencies (install with pip install -e ".[viz]") +# matplotlib>=3.3.0 +# seaborn>=0.11.0 + +# Documentation dependencies +# sphinx>=4.0.0 +# sphinx-rtd-theme>=1.0.0 + diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py new file mode 100644 index 0000000000..b19c18a13a --- /dev/null +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -0,0 +1,2253 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Cross-platform build script for declarative kernel workflow. + +Uses existing ctypes_utils.py for path management and codegen. + +Usage: + python3 compile_gemm_examples.py [output_name] + +Example: + python3 compile_gemm_examples.py examples/cpp/01_basic_gemm.cpp my_app +""" + +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path +import shutil + +# Add dispatcher/python to path to reuse existing utilities +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +# Import existing utilities (after sys.path modification) +from ctypes_utils import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + CodegenRunner, +) + + +# ============================================================================= +# Terminal Colors (cross-platform) +# ============================================================================= + + +class Colors: + if sys.platform != "win32" and sys.stdout.isatty(): + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + NC = "\033[0m" + else: + GREEN = YELLOW = RED = NC = "" + + +def print_phase(msg: str): + print(f"{Colors.YELLOW}{msg}{Colors.NC}") + + +def print_success(msg: str): + print(f"{Colors.GREEN}{msg}{Colors.NC}") + + +def print_error(msg: str): + print(f"{Colors.RED}{msg}{Colors.NC}", file=sys.stderr) + + +# ============================================================================= +# Compiler Detection +# ============================================================================= + + +def find_hipcc() -> str: + """Find hipcc compiler.""" + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + "/opt/rocm/hip/bin/hipcc", + shutil.which("hipcc"), + ] + + for path in candidates: + if path and os.path.isfile(path): + return path + + raise RuntimeError( + "hipcc not found. Please install ROCm or set HIPCC environment variable." + ) + + +# ============================================================================= +# Declaration Extraction +# ============================================================================= + + +def extract_conv_kernel_declarations(source_file: Path) -> list: + """Extract CONVOLUTION kernel declarations from C++ source file. + + Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. + """ + content = source_file.read_text() + declarations = [] + seen = set() + + # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + + for match in re.finditer(set_pattern, content, re.DOTALL): + set_name = match.group(1) + set_body = match.group(2) + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + dtype = add_match.group(1) + layout = add_match.group(2) + conv_type = add_match.group(3) + tile_k = int(add_match.group(4)) + tile_c = int(add_match.group(5)) + + name = f"{set_name}:{dtype}_{layout}_{conv_type}_{tile_k}x{tile_c}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": 2, + "groups": 1, + "tile_n": 1, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": -1, # Wildcard - will expand + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "set": set_name, + "arch": "gfx942", + } + ) + + # Pattern 2: Full specification with ConvSig() and ConvAlgo() + # Match .add( ConvSig()..., ConvAlgo()..., "arch" ) + # Use robust parsing that handles multi-line and comments + + # Find all .add( blocks containing ConvSig + add_blocks = re.findall( + r"\.add\s*\(\s*ConvSig\(\)([\s\S]*?)(?=\.add\s*\(|$)", set_body + ) + + for add_block in add_blocks: + # Find ConvAlgo and arch in this block + algo_match = re.search(r'ConvAlgo\(\)([\s\S]*?),\s*"(\w+)"\s*\)', add_block) + if not algo_match: + continue + + sig_str = add_block[: add_block.find("ConvAlgo()")] + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse ConvSig + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"([^"]+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + groups = 1 + groups_match = re.search(r"\.groups\s*\(\s*(\d+)", sig_str) + if groups_match: + groups = int(groups_match.group(1)) + + # Parse ConvAlgo + tile_n, tile_k, tile_c = 1, 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_match: + tile_n = int(tile_match.group(1)) + tile_k = int(tile_match.group(2)) + tile_c = int(tile_match.group(3)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv3" + pipeline_match = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Build unique name with full config + name = f"{set_name}:{dtype}_{conv_type}_{num_dims}d_{pipeline}_{scheduler}_{tile_k}x{tile_c}_{wave_m}x{wave_n}x{wave_k}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": num_dims, + "groups": groups, + "tile_n": tile_n, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "name": name, + "set": set_name, + "arch": arch, + } + ) + + return declarations + + +def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a convolution declaration to all valid combinations. + + Like GEMM, convolution supports wildcard expansion for: + - wave/warp: If -1, generates all valid combinations + - pipeline/scheduler: If "*", generates all valid trait combinations + """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback + WARP_SUPPORTED_COMBINATIONS = { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = set() + + d = decl.copy() + tile_k = d.get("tile_k", 128) + tile_c = d.get("tile_c", 128) + dtype = d.get("dtype", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_pipeline_expansion + and not needs_scheduler_expansion + ): + return [d] + + # Build valid combinations + if needs_wave_expansion or needs_warp_expansion: + wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_configs = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}).get( + dtype_key, [[32, 32, 16], [16, 16, 16]] + ) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler combinations + ALL_PIPELINES = ["compv3", "compv4"] + ALL_SCHEDULERS = ["intrawave", "interwave"] + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + + expanded = [] + + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility for conv (M=output spatial, N=K channels, K=C channels) + # Simplified check for now + if tile_k % (wn * wtn) != 0: + continue + if tile_c % (wk * wtk) != 0: + continue + + for pipeline in pipelines: + for scheduler in schedulers: + # Check trait combination + if ( + pipeline, + "cshuffle", + scheduler, + ) in TRAIT_UNSUPPORTED_COMBINATIONS: + continue + + expanded_d = d.copy() + expanded_d["wave_m"] = wm + expanded_d["wave_n"] = wn + expanded_d["wave_k"] = wk + expanded_d["warp_m"] = wtm + expanded_d["warp_n"] = wtn + expanded_d["warp_k"] = wtk + expanded_d["pipeline"] = pipeline + expanded_d["scheduler"] = scheduler + + expanded_d["name"] = ( + f"conv_{d['conv_type']}_{dtype}_{d['num_dims']}d_{pipeline}_" + f"{scheduler}_{tile_k}x{tile_c}_{wm}x{wn}x{wk}" + ) + expanded.append(expanded_d) + + if not expanded: + # Fallback to defaults + d["wave_m"] = 2 + d["wave_n"] = 2 + d["wave_k"] = 1 + d["warp_m"] = 32 + d["warp_n"] = 32 + d["warp_k"] = 16 + d["pipeline"] = "compv4" + d["scheduler"] = "intrawave" + return [d] + + return expanded + + +def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: + """Generate convolution kernels using unified_conv_codegen.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Import conv codegen + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from unified_conv_codegen import ( + UnifiedConvCodegen, + ConvKernelConfig, + ConvVariant, + TileConfig, + TraitConfig, + ) + except ImportError as e: + print_error(f" Failed to import conv codegen: {e}") + return 0 + + codegen = UnifiedConvCodegen(kernel_dir) + total_generated = 0 + + # Group by dtype and variant for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + key = (dtype, conv_type, num_dims) + if key not in groups: + groups[key] = [] + groups[key].append(decl) + + for (dtype, conv_type, num_dims), decls in groups.items(): + print(f" Generating {dtype} {conv_type} {num_dims}D kernels...") + + # Map to ConvVariant + variant = ConvVariant.FORWARD + if conv_type == "bwd_data": + variant = ConvVariant.BACKWARD_DATA + elif conv_type == "bwd_weight": + variant = ConvVariant.BACKWARD_WEIGHT + + for decl in decls: + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + epilogue = decl.get("epilogue", "cshuffle") + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + # Adjust tile_k for compv4 + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create TileConfig + tile_config = TileConfig( + tile_m=tile_k, # K is M in conv GEMM view + tile_n=tile_c, # C is N in conv GEMM view + tile_k=adj_tile_k, + warp_m=wave_m, + warp_n=wave_n, + warp_k=1, + warp_tile_m=warp_m, + warp_tile_n=warp_n, + warp_tile_k=warp_k, + ) + + # Create TraitConfig + trait_config = TraitConfig( + pipeline=pipeline, + scheduler=scheduler, + epilogue=epilogue, + double_smem_buffer=(pipeline == "compv4"), + pad_m=True, + pad_n=True, + pad_k=True, + ) + + # Create ConvKernelConfig + config = ConvKernelConfig( + tile=tile_config, + trait=trait_config, + variant=variant, + ndim_spatial=num_dims, + arch=gpu_target, + ) + + try: + filepath = codegen.generate_kernel(config, dtype) + total_generated += 1 + print(f" Generated: {filepath.name}") + except Exception as e: + print_error(f" Failed to generate {decl['name']}: {e}") + + return total_generated + + +# Original GEMM extraction continues here +def extract_kernel_declarations(source_file: Path) -> list: + """Extract GEMM kernel declarations from C++ source file.""" + content = source_file.read_text() + declarations = [] + seen = set() + + # ------------------------------------------------------------------------- + # Pattern 1: Simple DECL_KERNEL_SIMPLE(dtype, layout, tile_m, tile_n, tile_k) + # ------------------------------------------------------------------------- + legacy_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(legacy_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 2: Fluent API: DECL_KERNEL(Signature()..., Algorithm()..., arch) + # ------------------------------------------------------------------------- + # Match DECL_KERNEL( ... ); blocks + fluent_pattern = r'DECL_KERNEL\s*\(\s*(Signature\(\)[^,]+),\s*(Algorithm\(\)[^,]+)(?:,\s*"([^"]+)")?\s*\)' + + for match in re.finditer(fluent_pattern, content, re.DOTALL): + sig_str = match.group(1) + algo_str = match.group(2) + arch = match.group(3) or "gfx942" + + # Parse Signature + sig = {"dtype_a": "fp16", "dtype_b": "fp16", "dtype_c": "fp16", "layout": "rcr"} + + # .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16") + dtype_match = re.search( + r'\.dtype\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if dtype_match: + sig["dtype_a"] = dtype_match.group(1) + sig["dtype_b"] = dtype_match.group(2) or dtype_match.group(1) + sig["dtype_c"] = dtype_match.group(3) or dtype_match.group(1) + + # .layout("rcr") or .layout("row", "col", "row") + layout_match = re.search( + r'\.layout\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if layout_match: + if layout_match.group(2): # Three-arg form + la = layout_match.group(1) + lb = layout_match.group(2) + lc = layout_match.group(3) or "row" + sig["layout"] = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc == "row" else "c") + ) + else: # Single arg "rcr" + sig["layout"] = layout_match.group(1) + + # Parse Algorithm + algo = {} + + # .tile(128, 128, 32) + tile_match = re.search(r"\.tile\((\d+),\s*(\d+),\s*(\d+)\)", algo_str) + if tile_match: + algo["tile_m"] = int(tile_match.group(1)) + algo["tile_n"] = int(tile_match.group(2)) + algo["tile_k"] = int(tile_match.group(3)) + + # .wave(2, 2, 1) + wave_match = re.search(r"\.wave\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if wave_match: + algo["wave_m"] = int(wave_match.group(1)) + algo["wave_n"] = int(wave_match.group(2)) + algo["wave_k"] = int(wave_match.group(3) or 1) + + # .warp(32, 32, 16) + warp_match = re.search(r"\.warp\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if warp_match: + algo["warp_m"] = int(warp_match.group(1)) + algo["warp_n"] = int(warp_match.group(2)) + algo["warp_k"] = int(warp_match.group(3) or 16) + + # .pipeline("compv4"), .scheduler("intrawave"), .epilogue("cshuffle") + for field in ["pipeline", "scheduler", "epilogue"]: + fmatch = re.search(rf'\.{field}\("([^"]+)"\)', algo_str) + if fmatch: + algo[field] = fmatch.group(1) + + # Build declaration + tm = algo.get("tile_m", 128) + tn = algo.get("tile_n", 128) + tk = algo.get("tile_k", 32) + + name = f"{sig['dtype_a']}_{sig['layout']}_{tm}x{tn}x{tk}" + + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": sig["dtype_a"], + "dtype_b": sig["dtype_b"], + "dtype_c": sig["dtype_c"], + "layout": sig["layout"], + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": algo.get("wave_m", -1), + "wave_n": algo.get("wave_n", -1), + "wave_k": algo.get("wave_k", 1), + "warp_m": algo.get("warp_m", -1), + "warp_n": algo.get("warp_n", -1), + "warp_k": algo.get("warp_k", 16), + "pipeline": algo.get("pipeline", "compv4"), + "scheduler": algo.get("scheduler", "intrawave"), + "epilogue": algo.get("epilogue", "cshuffle"), + "arch": arch, + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 3: DECL_KERNEL_ALL(dtype, layout) - wildcard + # ------------------------------------------------------------------------- + all_pattern = r"DECL_KERNEL(?:S)?_ALL\s*\(\s*(\w+)\s*,\s*(\w+)\s*\)" + for match in re.findall(all_pattern, content): + dtype, layout = match + name = f"wildcard_{dtype}_{layout}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": -1, + "tile_n": -1, + "tile_k": -1, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": True, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 4: DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk) + # ------------------------------------------------------------------------- + simple_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(simple_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": None, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 5: DECL_KERNEL_SET(name, .add(...).add(...)) + # Named kernel sets for multiple registries + # Match only DECL_KERNEL_SET at start of line (not in comments) + # ------------------------------------------------------------------------- + set_pattern = r"^DECL_KERNEL_SET\s*\(\s*(\w+)\s*,([\s\S]*?)\)\s*;" + for match in re.finditer(set_pattern, content, re.MULTILINE): + set_name = match.group(1) + set_body = match.group(2) + + # Parse .add("dtype", "layout", tm, tn, tk) calls - simple form + add_simple = r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)' + for add_match in re.findall(add_simple, set_body): + dtype, layout, tm, tn, tk = add_match + name = f"{set_name}:{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + # Parse .add(Signature()..., Algorithm()..., "arch") fluent calls + # Robust approach: find each .add( block and parse methods individually + # This handles any method order and optional methods + + # Split set_body into .add() blocks + add_blocks = [] + add_starts = [m.start() for m in re.finditer(r"\.add\s*\(", set_body)] + + for i, start in enumerate(add_starts): + # Find the matching closing paren by counting parens + depth = 0 + end = start + in_string = False + escape_next = False + + for j, ch in enumerate(set_body[start:], start): + if escape_next: + escape_next = False + continue + if ch == "\\": + escape_next = True + continue + if ch == '"' and not escape_next: + in_string = not in_string + continue + if in_string: + continue + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + end = j + 1 + break + + if end > start: + add_blocks.append(set_body[start:end]) + + for add_block in add_blocks: + # Skip if doesn't have both Signature() and Algorithm() + if "Signature()" not in add_block or "Algorithm()" not in add_block: + continue + + # Split on Algorithm() to separate Signature and Algorithm parts + algo_idx = add_block.find("Algorithm()") + if algo_idx == -1: + continue + + sig_str = add_block[:algo_idx] + algo_str = add_block[algo_idx:] # Include Algorithm() and everything after + + # Parse dtype from Signature - handles .dtype("fp16", "fp16", "fp16", "fp32") + dtype = "fp16" + dtype_m = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_m: + dtype = dtype_m.group(1) + + # Parse layout from Signature - handles .layout("row", "col", "row") + layout = "rcr" + layout_m = re.search( + r'\.layout\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"', sig_str + ) + if layout_m: + la, lb, lc = layout_m.group(1), layout_m.group(2), layout_m.group(3) + layout = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc == "row" else "c") + ) + else: + # Single arg form: .layout("rcr") + layout_m = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_m: + layout = layout_m.group(1) + + # Parse tile from Algorithm + tm, tn, tk = 128, 128, 32 + tile_m = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_m: + tm, tn, tk = ( + int(tile_m.group(1)), + int(tile_m.group(2)), + int(tile_m.group(3)), + ) + + # Parse wave + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m, wave_n = int(wave_match.group(1)), int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + # Parse warp + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m, warp_n = int(warp_match.group(1)), int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + # Parse pipeline - NEW: extract from declaration + pipeline = "compv4" + pipeline_m = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_m: + pipeline = pipeline_m.group(1) + + # Parse scheduler - NEW: extract from declaration + scheduler = "intrawave" + scheduler_m = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_m: + scheduler = scheduler_m.group(1) + + # Parse epilogue - NEW: extract from declaration + epilogue = "cshuffle" + epilogue_m = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_m: + epilogue = epilogue_m.group(1) + + # Parse padding - NEW: extract from declaration + pad_m, pad_n, pad_k = False, False, False + pad_match = re.search( + r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*\)", + algo_str, + re.IGNORECASE, + ) + if pad_match: + pad_m = pad_match.group(1).lower() == "true" + pad_n = pad_match.group(2).lower() == "true" + pad_k = pad_match.group(3).lower() == "true" + + # Parse elementwise from Signature - for Multi-D kernels + elementwise_op = "PassThrough" + num_d_tensors = 0 + elem_match = re.search( + r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)\s*\)', + sig_str, + ) + if elem_match: + elementwise_op = elem_match.group(1) + num_d_tensors = int(elem_match.group(2)) + + name = f"{set_name}:{dtype}_{layout}_{pipeline}_{scheduler}_{tm}x{tn}x{tk}_{wave_m}x{wave_n}x{wave_k}" + if elementwise_op != "PassThrough": + name += f"_{elementwise_op}_d{num_d_tensors}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "elementwise_op": elementwise_op, + "num_d_tensors": num_d_tensors, + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + return declarations + + +def expand_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a declaration to all valid combinations using arch filter. + + Expands wildcards for: + - wave/warp: If -1, generates all valid wave/warp_tile combinations + - pipeline/scheduler/epilogue: If "*", generates all valid trait combinations + + Uses the arch_filter module for architecture-specific validation. + """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback to hardcoded valid combinations + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + d = decl.copy() + tm = d.get("tile_m", 128) + tn = d.get("tile_n", 128) + tk = d.get("tile_k", 32) + dtype = d.get("dtype_a", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + needs_epilogue_expansion = d.get("epilogue", "cshuffle") == "*" + needs_pad_m_expansion = d.get("pad_m", 1) == -1 + needs_pad_n_expansion = d.get("pad_n", 1) == -1 + needs_pad_k_expansion = d.get("pad_k", 1) == -1 + needs_trait_expansion = ( + needs_pipeline_expansion + or needs_scheduler_expansion + or needs_epilogue_expansion + ) + needs_pad_expansion = ( + needs_pad_m_expansion or needs_pad_n_expansion or needs_pad_k_expansion + ) + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_trait_expansion + and not needs_pad_expansion + ): + # Already fully specified + return [d] + + # === Build valid combinations === + + # Wave configurations + if needs_wave_expansion: + wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + + # Warp tile configurations + if needs_warp_expansion: + arch_warp_tiles = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}) + + # Try to find warp tile configs for this dtype + # Keys are like: fp16_fp16_fp32, int8_int8_int32, etc. + warp_tile_configs = None + dtype_key_variants = [ + f"{dtype}_{dtype}_{dtype}", # e.g., fp32_fp32_fp32 + f"{dtype}_{dtype}_fp32", # e.g., fp16_fp16_fp32 + f"{dtype}_{dtype}_int32", # e.g., int8_int8_int32 + ] + for dtype_key in dtype_key_variants: + warp_tile_configs = arch_warp_tiles.get(dtype_key, None) + if warp_tile_configs is not None: + break + + # If dtype is not supported on this arch, return empty list + if warp_tile_configs is None: + return [] + else: + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler/epilogue combinations + # Valid options per category + ALL_PIPELINES = ["compv3", "compv4"] # Most common; add more if needed + ALL_SCHEDULERS = ["intrawave", "interwave"] + ALL_EPILOGUES = ["cshuffle", "default"] + ALL_PAD_OPTIONS = [False, True] # 0 and 1 + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + epilogues = ( + ALL_EPILOGUES if needs_epilogue_expansion else [d.get("epilogue", "cshuffle")] + ) + pad_m_opts = ALL_PAD_OPTIONS if needs_pad_m_expansion else [bool(d.get("pad_m", 1))] + pad_n_opts = ALL_PAD_OPTIONS if needs_pad_n_expansion else [bool(d.get("pad_n", 1))] + pad_k_opts = ALL_PAD_OPTIONS if needs_pad_k_expansion else [bool(d.get("pad_k", 1))] + + expanded = [] + + # Generate all valid combinations + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility constraints + if tm % (wm * wtm) != 0: + continue + if tn % (wn * wtn) != 0: + continue + if tk % (wk * wtk) != 0: + continue + + for pipeline in pipelines: + for scheduler in schedulers: + for epilogue in epilogues: + # Check trait combination is valid + if ( + pipeline, + epilogue, + scheduler, + ) in TRAIT_UNSUPPORTED_COMBINATIONS: + continue + + for pad_m in pad_m_opts: + for pad_n in pad_n_opts: + for pad_k in pad_k_opts: + # Create expanded declaration + expanded_d = d.copy() + expanded_d["wave_m"] = wm + expanded_d["wave_n"] = wn + expanded_d["wave_k"] = wk + expanded_d["warp_m"] = wtm + expanded_d["warp_n"] = wtn + expanded_d["warp_k"] = wtk + expanded_d["pipeline"] = pipeline + expanded_d["scheduler"] = scheduler + expanded_d["epilogue"] = epilogue + expanded_d["pad_m"] = int(pad_m) + expanded_d["pad_n"] = int(pad_n) + expanded_d["pad_k"] = int(pad_k) + + pad_str = f"{'T' if pad_m else 'F'}{'T' if pad_n else 'F'}{'T' if pad_k else 'F'}" + expanded_d["name"] = ( + f"{dtype}_{d.get('layout', 'rcr')}_{pipeline}_{scheduler}_" + f"pad{pad_str}_{tm}x{tn}x{tk}_{wm}x{wn}x{wk}" + ) + expanded_d["wildcard"] = False + expanded.append(expanded_d) + + if not expanded: + # No valid combinations found, return single default + print(f" Warning: No valid combinations for {tm}x{tn}x{tk} on {arch}") + d["wave_m"] = 2 + d["wave_n"] = 2 + d["wave_k"] = 1 + d["warp_m"] = 32 + d["warp_n"] = 32 + d["warp_k"] = 16 + d["pipeline"] = "compv4" + d["scheduler"] = "intrawave" + d["epilogue"] = "cshuffle" + return [d] + + return expanded + + +def auto_fill_declaration(decl: dict) -> dict: + """Auto-fill with single default (for backward compat).""" + expanded = expand_declaration_with_arch_filter(decl, decl.get("arch", "gfx942")) + return expanded[0] if expanded else decl + + +# ============================================================================= +# Build Functions +# ============================================================================= + + +def generate_kernels(declarations: list, gpu_target: str = "gfx942") -> int: + """Generate kernels using CodegenRunner from ctypes_utils.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Group by dtype+layout for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + key = (dtype, layout) + if key not in groups: + groups[key] = [] + groups[key].append(auto_fill_declaration(decl)) + + total_generated = 0 + + for (dtype, layout), decls in groups.items(): + print(f" Generating {dtype} {layout} kernels...") + + # Check for wildcards - if any decl is wildcard, generate all + has_wildcard = any(d.get("wildcard", False) for d in decls) + + # Use CodegenRunner from ctypes_utils + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + + if result.success: + total_generated += result.kernel_count + if has_wildcard: + print(f" [wildcard] Generated all {result.kernel_count} variants") + else: + print_error(f" Failed: {result.stderr[:200]}") + + return total_generated + + +def get_arch_filter_data(): + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + +def is_wildcard_declaration(decl: dict) -> bool: + """Check if declaration has wildcards that need expansion.""" + # Wave/warp wildcards + if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0: + return True + if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0: + return True + # Pipeline/scheduler wildcards + if decl.get("pipeline", "compv4") == "*": + return True + if decl.get("scheduler", "intrawave") == "*": + return True + if decl.get("epilogue", "cshuffle") == "*": + return True + return False + + +def validate_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a kernel configuration against known supported combinations. + + Uses arch_specs_generated for architecture-specific validation. + + For wildcard declarations (-1 values or "*" strings), validation is skipped + because the expansion phase will generate only valid combinations. + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype_a", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def build_exact_kernel_filename(decl: dict) -> str: + """Build the exact kernel filename from a fully-specified declaration. + + Standard format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}.hpp + + Multi-D format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}_multid_{op}_d{num}.hpp + """ + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + pipeline = decl.get("pipeline", "compv4") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + + pad_m = "True" if decl.get("pad_m", False) else "False" + pad_n = "True" if decl.get("pad_n", False) else "False" + pad_k = "True" if decl.get("pad_k", False) else "False" + preshuffle = "True" if decl.get("preshuffle", False) else "False" + + tile_m = decl.get("tile_m", 128) + tile_n = decl.get("tile_n", 128) + tile_k = decl.get("tile_k", 32) + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + warp_str = f"{warp_m}x{warp_n}x{warp_k}" + + base = f"gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile_str}_{wave_str}_{warp_str}" + + # Handle Multi-D kernels + elementwise_op = decl.get("elementwise_op", "PassThrough") + num_d_tensors = decl.get("num_d_tensors", 0) + if elementwise_op != "PassThrough" and num_d_tensors > 0: + base += f"_multid_{elementwise_op}_d{num_d_tensors}" + + return f"{base}.hpp" + + +def generate_specific_kernel(decl: dict, gpu_target: str = "gfx942") -> bool: + """Generate a specific kernel based on declaration.""" + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + + print(f" Generating kernel for {dtype}/{layout}...") + + # Use CodegenRunner to generate + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + return result.success + + +def find_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find a matching kernel header file for a declaration. + + Tries multiple matching strategies: + 1. Exact filename match + 2. Match with key parameters (dtype, layout, pipeline, scheduler, tile) + 3. Match with just dtype, layout, and tile (more flexible) + 4. Any kernel with matching dtype and layout + + If no kernel exists, attempts to generate it. + Returns None only if all strategies fail. + """ + kernel_dir = get_generated_kernels_dir() + + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + tile_m = decl.get("tile_m", 128) + tile_n = decl.get("tile_n", 128) + tile_k = decl.get("tile_k", 32) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Build exact filename + exact_filename = build_exact_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Strategy 1: Exact filename match + if exact_path.exists(): + print(f" Found exact kernel: {exact_filename}") + return exact_path + + # Strategy 2: Match with key parameters + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found matching kernel: {matches[0].name}") + return matches[0] + + # Strategy 3: Match with just dtype, layout, tile (ignore wave/warp) + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with matching tile: {matches[0].name}") + return matches[0] + + # Strategy 4: Match with just dtype, layout (most flexible, for wildcards) + # Prefer kernels with intrawave scheduler (known to work) + pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with intrawave: {matches[0].name}") + return matches[0] + + # Strategy 5: Any kernel with matching dtype and layout + pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found kernel with matching dtype/layout/tile: {matches[0].name}") + return matches[0] + + # Strategy 6: Try to generate the kernel + print(" No matching kernel found, attempting to generate...") + if generate_specific_kernel(decl, gpu_target): + # Check strategies again after generation + for pattern in [ + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp", + f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp", + f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp", + ]: + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Generated: {matches[0].name}") + return matches[0] + + # All strategies failed - return None (caller will try next expanded decl) + return None + + +def is_conv_wildcard_declaration(decl: dict) -> bool: + """Check if conv declaration has wildcards that need expansion.""" + if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0: + return True + if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0: + return True + if decl.get("pipeline", "compv3") == "*": + return True + if decl.get("scheduler", "intrawave") == "*": + return True + return False + + +def validate_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a conv kernel configuration against arch filter. + + For wildcard declarations, validation is skipped (expansion handles it). + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards + if is_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def build_exact_conv_kernel_filename(decl: dict) -> str: + """Build the exact conv kernel filename from a fully-specified declaration. + + Conv filename format: + conv_{type}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}_{tile}_{wave}.hpp + + Example: + conv_fwd_fp16_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1.hpp + """ + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + + # Map conv_type to filename prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_k}x{tile_c}x32" # Conv uses tile_k x tile_c x 32 format + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + return f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_{epilogue}_{scheduler}_{tile_str}_{wave_str}.hpp" + + +def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> bool: + """Generate a specific conv kernel based on declaration.""" + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + + print(f" Generating conv kernel for {dtype}/{conv_type}/{num_dims}d...") + + # Map to variant name + if conv_type == "forward": + variant = "forward" + elif conv_type == "bwd_data": + variant = "bwd_data" + elif conv_type == "bwd_weight": + variant = "bwd_weight" + else: + variant = "forward" + + # Use unified_conv_codegen + codegen_dir = get_dispatcher_root() / "codegen" + codegen_script = codegen_dir / "unified_conv_codegen.py" + output_dir = get_generated_kernels_dir() + + cmd = [ + "python3", + str(codegen_script), + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(num_dims), + "--arch", + gpu_target, + "--output", + str(output_dir), + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + + +def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find the EXACT matching conv kernel header file for a declaration. + + If the kernel doesn't exist, attempts to generate it. + Returns None only if generation also fails. + """ + kernel_dir = get_generated_kernels_dir() + + # Build exact filename + exact_filename = build_exact_conv_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Check if exact kernel exists + if exact_path.exists(): + print(f" Found exact conv kernel: {exact_filename}") + return exact_path + + # Try to find with glob (in case of minor variations) + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + # Map conv_type to prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_str = f"{tile_k}x{tile_c}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Search pattern with key parameters + pattern = f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_*_{scheduler}_*{tile_str}*_{wave_str}.hpp" + matches = list(kernel_dir.glob(pattern)) + + if matches: + print(f" Found matching conv kernel: {matches[0].name}") + return matches[0] + + # Kernel doesn't exist - try to generate it + print(f" Conv kernel not found: {exact_filename}") + print(" Attempting to generate...") + + if generate_specific_conv_kernel(decl, gpu_target): + # Check again after generation + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Generated: {matches[0].name}") + return matches[0] + + # Check for exact match + if exact_path.exists(): + print(f" Generated: {exact_filename}") + return exact_path + + # Still not found - print helpful error + print_error( + " ERROR: Could not find or generate conv kernel matching declaration:" + ) + print_error(f" dtype={dtype}, conv_type={conv_type}, num_dims={num_dims}") + print_error(f" pipeline={pipeline}, scheduler={scheduler}") + print_error(f" tile={tile_k}x{tile_c}, wave={wave_str}") + print_error(f" Expected: {exact_filename}") + print_error(f" Available conv kernels in {kernel_dir}:") + + available = list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))[ + :5 + ] + for k in available: + print_error(f" - {k.name}") + if len(list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))) > 5: + print_error(" ... and more") + + return None + + +def build_dispatcher_library(hipcc: str) -> bool: + """Build the dispatcher library if needed.""" + build_dir = get_build_dir() + lib_path = build_dir / "libck_tile_dispatcher.a" + + if lib_path.exists(): + return True + + print(" Building dispatcher library...") + build_dir.mkdir(parents=True, exist_ok=True) + + dispatcher_dir = get_dispatcher_root() + + # Run cmake + cmake_cmd = ["cmake", str(dispatcher_dir), f"-DCMAKE_CXX_COMPILER={hipcc}"] + result = subprocess.run( + cmake_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"CMake failed: {result.stderr}") + return False + + # Run make + make_cmd = ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"] + result = subprocess.run( + make_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"Make failed: {result.stderr}") + return False + + return True + + +def compile_application( + source_file: Path, + output_bin: Path, + kernel_header: Path, + hipcc: str, + gpu_target: str = "gfx942", +) -> bool: + """Compile the application with hipcc.""" + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + cmd = [ + hipcc, + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + *includes, + "-include", + str(kernel_header), + f"-L{build_dir}", + "-lck_tile_dispatcher", + "-o", + str(output_bin), + str(source_file), + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + # Filter out nodiscard warnings + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()] + if errors: + for err_line in errors[:5]: + print_error(f" {err_line}") + + return result.returncode == 0 + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="Build CK Tile application with declarative kernels", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Example: + python3 compile_gemm_examples.py examples/cpp/01_basic_gemm_declarative.cpp my_app + +In your C++ code, declare kernels like: + DECL_KERNEL_SET(my_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(128, 128, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave")) + ); +""", + ) + parser.add_argument("source", help="Source file (.cpp)") + parser.add_argument( + "output", nargs="?", help="Output name (default: source basename)" + ) + parser.add_argument( + "--gpu-target", default="gfx942", help="GPU target architecture" + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + args = parser.parse_args() + + # Resolve paths using utilities from ctypes_utils + dispatcher_dir = get_dispatcher_root() + build_dir = get_build_dir() + + source_file = Path(args.source) + if not source_file.is_absolute(): + # Try relative to dispatcher dir first, then CWD + candidates = [ + dispatcher_dir / args.source, + dispatcher_dir / "examples" / args.source, # examples/gemm/cpp/... + Path.cwd() / args.source, + ] + for candidate in candidates: + if candidate.exists(): + source_file = candidate + break + + if not source_file.exists(): + print_error(f"Source file not found: {source_file}") + return 1 + + output_name = args.output or source_file.stem + output_bin = build_dir / output_name + + # Ensure build directory exists + build_dir.mkdir(parents=True, exist_ok=True) + + print_success("=== CK Tile Declarative Kernel Build ===") + print() + + # Phase 1: Extract declarations (both GEMM and Conv) + print_phase("Phase 1: Scanning for kernel declarations...") + + gemm_declarations = extract_kernel_declarations(source_file) + conv_declarations = extract_conv_kernel_declarations(source_file) + + if not gemm_declarations and not conv_declarations: + print_error(" No kernel declarations found!") + print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv") + return 1 + + # Handle GEMM declarations + if gemm_declarations: + print(f"\n GEMM: Found {len(gemm_declarations)} declaration(s)") + + # Group by kernel set + sets = {} + for decl in gemm_declarations: + set_name = decl.get("set") or "(global)" + if set_name not in sets: + sets[set_name] = [] + sets[set_name].append(decl) + + for set_name, set_decls in sets.items(): + print(f" [{set_name}] ({len(set_decls)} kernels):") + for decl in set_decls[:5]: + needs_expansion = ( + decl.get("wave_m", -1) < 0 or decl.get("warp_m", -1) < 0 + ) + suffix = " [expands]" if needs_expansion else "" + display_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + print(f" - {display_name}{suffix}") + if len(set_decls) > 5: + print(f" ... and {len(set_decls) - 5} more") + + # Validate declarations against arch filter + print(f"\n Validating against {args.gpu_target} arch filter...") + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in gemm_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + # Check for wildcards + if is_wildcard_declaration(decl): + wildcard_count += 1 + continue # Wildcards validated during expansion + + is_valid, error_msg = validate_kernel_config(decl, arch) + if not is_valid: + print(f"\n ⚠ Invalid configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} → [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} → [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv4") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" • {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" ✓ {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" ✓ All {len(gemm_declarations)} configurations valid") + + # Expand GEMM declarations (for wildcards) + print("\n Expanding wildcards to valid configurations...") + expanded_gemm = [] + for decl in gemm_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + expanded = expand_declaration_with_arch_filter(decl, arch) + expanded_gemm.extend(expanded) + + # Show what the wildcard expanded to + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + # Show first few expanded configs + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif len(expanded) == 1 and is_wildcard_declaration(decl): + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + + if len(expanded_gemm) > len(gemm_declarations): + print( + f"\n Total: {len(gemm_declarations)} declarations → {len(expanded_gemm)} configurations" + ) + + gemm_declarations = expanded_gemm + + # Handle Conv declarations + if conv_declarations: + print(f"\n CONV: Found {len(conv_declarations)} declaration(s)") + + # Group by kernel set + sets = {} + for decl in conv_declarations: + set_name = decl.get("set") or "(global)" + if set_name not in sets: + sets[set_name] = [] + sets[set_name].append(decl) + + for set_name, set_decls in sets.items(): + print(f" [{set_name}] ({len(set_decls)} kernels):") + for decl in set_decls[:5]: + needs_expansion = is_conv_wildcard_declaration(decl) + suffix = " [expands]" if needs_expansion else "" + display_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + print(f" - {display_name}{suffix}") + if len(set_decls) > 5: + print(f" ... and {len(set_decls) - 5} more") + + # Validate Conv declarations against arch filter + print(f"\n Validating against {args.gpu_target} arch filter...") + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in conv_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + # Check for wildcards + if is_conv_wildcard_declaration(decl): + wildcard_count += 1 + continue # Wildcards validated during expansion + + is_valid, error_msg = validate_conv_kernel_config(decl, arch) + if not is_valid: + print(f"\n ⚠ Invalid conv configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} → [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} → [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv3") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" • {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" ✓ {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" ✓ All {len(conv_declarations)} configurations valid") + + # Expand Conv declarations (for wildcards) + print("\n Expanding wildcards to valid configurations...") + expanded_conv = [] + for decl in conv_declarations: + arch = decl.get("arch", args.gpu_target) + decl_name = ( + decl["name"].split(":")[-1] if ":" in decl["name"] else decl["name"] + ) + + expanded = expand_conv_declaration_with_arch_filter(decl, arch) + expanded_conv.extend(expanded) + + # Show what the wildcard expanded to + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif len(expanded) == 1 and is_conv_wildcard_declaration(decl): + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + + if len(expanded_conv) > len(conv_declarations): + print( + f"\n Total: {len(conv_declarations)} declarations → {len(expanded_conv)} configurations" + ) + + conv_declarations = expanded_conv + + print() + + # Phase 2: Generate kernels + print_phase("Phase 2: Generating kernels...") + + total_generated = 0 + + # Generate GEMM kernels + if gemm_declarations: + print(" GEMM kernels:") + num_gemm = generate_kernels(gemm_declarations, args.gpu_target) + total_generated += num_gemm + print(f" Generated: {num_gemm}") + + # Generate Conv kernels + if conv_declarations: + print(" CONV kernels:") + num_conv = generate_conv_kernels(conv_declarations, args.gpu_target) + total_generated += num_conv + print(f" Generated: {num_conv}") + + print(f" Total kernel files: {total_generated}") + print() + + # Phase 3: Find kernel header + print_phase("Phase 3: Selecting kernel for compilation...") + + kernel_headers = [] + + # Find GEMM kernel header (try each expanded declaration until one matches) + if gemm_declarations: + gemm_header = None + for decl in gemm_declarations: + header = find_kernel_header(decl, args.gpu_target) + if header: + gemm_header = header + break + + if gemm_header: + kernel_headers.append(gemm_header) + print(f" GEMM: {gemm_header.name}") + else: + print_error(" GEMM: No kernel found matching any declaration!") + print_error( + " The kernels declared in DECL_KERNEL_SET must exist or be generatable." + ) + return 1 + + # Find Conv kernel header + if conv_declarations: + first_conv = conv_declarations[0] + conv_header = find_conv_kernel_header(first_conv) + if conv_header: + kernel_headers.append(conv_header) + print(f" CONV: {conv_header.name}") + + if not kernel_headers: + print_error(" No kernel headers found!") + return 1 + + # Use first available header (can be extended to use multiple) + kernel_header = kernel_headers[0] + print() + + # Phase 4: Build dispatcher library + print_phase("Phase 4: Building dispatcher library...") + hipcc = find_hipcc() + + if not build_dispatcher_library(hipcc): + print_error(" Failed to build dispatcher library!") + return 1 + print(" Done") + print() + + # Phase 5: Compile application + print_phase("Phase 5: Compiling application...") + + if not compile_application( + source_file, output_bin, kernel_header, hipcc, args.gpu_target + ): + print_error(" Compilation failed!") + return 1 + + print(f" Output: {output_bin}") + print() + + # Done + print_success("=== Build Complete ===") + print() + print("Run with:") + print(f" {output_bin}") + print() + print("List declared kernels:") + print(f" {output_bin} --list-kernels") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py new file mode 100755 index 0000000000..d3bb619174 --- /dev/null +++ b/dispatcher/scripts/example_kernel_builder.py @@ -0,0 +1,1447 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build example kernels - generates and compiles kernels for a single example. + +Detects if example is GEMM or Conv based on macro presence, extracts all +configuration parameters, and generates appropriate kernels. +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List, Tuple + + +def find_hipcc() -> str: + for path in [os.environ.get("HIPCC"), "/opt/rocm/bin/hipcc", shutil.which("hipcc")]: + if path and os.path.isfile(path): + return path + return "hipcc" + + +def find_ar() -> str: + for path in [ + "/opt/rocm/llvm/bin/llvm-ar", + shutil.which("llvm-ar"), + shutil.which("ar"), + ]: + if path and os.path.isfile(path): + return path + return "ar" + + +def extract_balanced_parens(text: str, start_pos: int) -> str: + """Extract content between balanced parentheses.""" + if start_pos >= len(text) or text[start_pos] != "(": + return "" + depth = 0 + for i, c in enumerate(text[start_pos:], start_pos): + if c == "(": + depth += 1 + elif c == ")": + depth -= 1 + if depth == 0: + return text[start_pos + 1 : i] + return "" + + +def parse_conv_declarations(content: str) -> List[Dict]: + """Parse DECL_CONV_KERNEL_SET declarations with all parameters.""" + kernels = [] + + for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content): + body = extract_balanced_parens(content, match.end() - 1) + if not body: + continue + + # Parse each .add() call + for add_match in re.finditer(r"\.add\s*\(", body): + add_body = extract_balanced_parens(body, add_match.end() - 1) + + kernel = {} + + # ConvSig parameters - handle both single dtype and multi-dtype + # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16", "bf16", "fp16") + if m := re.search( + r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)', + add_body, + ): + kernel["dtype_in"] = m.group(1) + kernel["dtype_wei"] = m.group(2) + kernel["dtype_out"] = m.group(3) + kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32" + kernel["dtype"] = m.group(1) # Default for codegen + # Single dtype: .dtype("fp16") + elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body): + kernel["dtype"] = m.group(1) + kernel["dtype_in"] = m.group(1) + kernel["dtype_wei"] = m.group(1) + kernel["dtype_out"] = m.group(1) + kernel["dtype_acc"] = "fp32" + if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body): + kernel["layout"] = m.group(1) + if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body): + kernel["conv_type"] = m.group(1) + if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body): + kernel["ndim"] = int(m.group(1)) + + # ConvAlgo parameters - tile(G, M, N) where G=batch, M=output, N=reduction + if m := re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["tile_g"] = int(m.group(1)) # batch tile (usually 1) + kernel["tile_m"] = int(m.group(2)) # output channel tile + kernel["tile_n"] = int(m.group(3)) # input channel tile (reduction) + + # wave(M_Warp, N_Warp, K_Warp) - warp distribution + if m := re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["warp_m"] = int(m.group(1)) + kernel["warp_n"] = int(m.group(2)) + kernel["warp_k"] = int(m.group(3)) + + # warp(M_Warp_Tile, N_Warp_Tile, K_Warp_Tile) - warp tile sizes + if m := re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["warp_tile_m"] = int(m.group(1)) + kernel["warp_tile_n"] = int(m.group(2)) + kernel["warp_tile_k"] = int(m.group(3)) + + # vector_sizes(A, B, C) + if m := re.search( + r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["vector_a"] = int(m.group(1)) + kernel["vector_b"] = int(m.group(2)) + kernel["vector_c"] = int(m.group(3)) + + # Single-value parameters + if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body): + kernel["pipeline"] = m.group(1) + if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body): + kernel["scheduler"] = m.group(1) + if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body): + kernel["epilogue"] = m.group(1) + if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body): + kernel["block_per_cu"] = int(m.group(1)) + if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body): + kernel["num_wave_groups"] = int(m.group(1)) + if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body): + kernel["num_groups_to_merge"] = int(m.group(1)) + if m := re.search( + r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I + ): + kernel["double_smem_buffer"] = m.group(1).lower() == "true" + + # Architecture + if m := re.search(r'"(gfx\d+)"', add_body): + kernel["arch"] = m.group(1) + + if kernel.get("dtype"): + # Auto-fill missing parameters with defaults (autocorrect) + kernel = auto_fill_conv_defaults(kernel) + kernels.append(kernel) + + return kernels + + +def auto_fill_conv_defaults(kernel: Dict) -> Dict: + """Auto-fill missing conv parameters with sensible defaults (autofill + autocorrect). + + This implements: + 1. AUTOFILL: Missing parameters are filled with valid defaults (ConvConfigComputeV3) + 2. AUTOCORRECT: Invalid values are corrected to valid ones + """ + # Default tile configuration matching ConvConfigComputeV3 + defaults = { + "tile_g": 1, + "tile_m": 16, + "tile_n": 64, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "vector_a": 4, + "vector_b": 8, + "vector_c": 8, + "block_per_cu": 1, + "num_wave_groups": 1, + "num_groups_to_merge": 1, + "ndim": 2, + "layout": "nhwgc", + "conv_type": "forward", + "arch": "gfx942", + } + + # AUTOFILL: Fill missing parameters with defaults + autofilled = [] + for key, value in defaults.items(): + if key not in kernel or kernel[key] is None or kernel[key] == -1: + kernel[key] = value + autofilled.append(f"{key}={value}") + + if autofilled: + print(f" [AUTOFILL] {', '.join(autofilled)}") + + # AUTOCORRECT: Fix invalid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + current_wave = ( + kernel.get("warp_m", 1), + kernel.get("warp_n", 4), + kernel.get("warp_k", 1), + ) + + if current_wave not in valid_wave_configs: + old = current_wave + kernel["warp_m"] = 1 + kernel["warp_n"] = 4 + kernel["warp_k"] = 1 + print(f" [AUTOCORRECT] wave{old} -> wave(1,4,1) (invalid for gfx942)") + + # AUTOCORRECT: Fix invalid pipeline for backward ops + conv_type = kernel.get("conv_type", "forward") + pipeline = kernel.get("pipeline", "compv3") + + if conv_type in ["bwd_data", "bwd_weight"] and pipeline in ["compv4", "compv5"]: + old_pipeline = pipeline + kernel["pipeline"] = "compv3" + print( + f" [AUTOCORRECT] pipeline {old_pipeline} -> compv3 (invalid for {conv_type})" + ) + + return kernel + + +def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]: + """Expand wildcard parameters to multiple valid configurations. + + When users specify wildcards (-1 or *), this expands them to all + valid configurations for the target architecture. + """ + expanded = [] + + # Valid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + # Valid warp tile configurations for gfx942 fp16 + valid_warp_configs = [(16, 16, 32), (32, 32, 16)] + + # Check if expansion is needed + needs_wave = kernel.get("warp_m") is None or kernel.get("warp_m") == -1 + needs_warp = kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1 + + if not needs_wave and not needs_warp: + return [kernel] + + # Expand wave configurations + wave_configs = ( + valid_wave_configs + if needs_wave + else [ + (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1)) + ] + ) + + # Expand warp tile configurations + warp_configs = ( + valid_warp_configs + if needs_warp + else [ + ( + kernel.get("warp_tile_m", 32), + kernel.get("warp_tile_n", 32), + kernel.get("warp_tile_k", 16), + ) + ] + ) + + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_configs: + new_kernel = kernel.copy() + new_kernel["warp_m"] = wm + new_kernel["warp_n"] = wn + new_kernel["warp_k"] = wk + new_kernel["warp_tile_m"] = wtm + new_kernel["warp_tile_n"] = wtn + new_kernel["warp_tile_k"] = wtk + expanded.append(new_kernel) + + return expanded + + +def parse_int_or_wildcard(val: str) -> int: + """Parse integer or return -1 for wildcards. + + Supported wildcard formats: + - ANY_INT: Macro defined as -1 + - -1: Direct numeric wildcard + - "*": String wildcard (also maps to -1 for integer params) + """ + val = val.strip() + if val == "ANY_INT" or val == "-1" or val == "*": + return -1 + return int(val) + + +def parse_gemm_declarations(content: str) -> List[Dict]: + """Parse DECL_KERNEL_SET declarations for GEMM. + + Supports wildcards: + - ANY_INT for numeric params (wave, warp) -> expands to all valid combos + - "*" for string params (pipeline, scheduler) -> expands to valid options + + Each kernel is tagged with its kernel_set name for separate registration. + """ + kernels = [] + + for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content): + kernel_set_name = match.group(1) + body = extract_balanced_parens( + content, match.start() + content[match.start() :].find("(") + ) + if not body: + continue + + for add_match in re.finditer(r"\.add\s*\(", body): + add_body = extract_balanced_parens(body, add_match.end() - 1) + + kernel = {} + + # Signature parameters + if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body): + kernel["dtype"] = m.group(1) + if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body): + kernel["layout"] = m.group(1) + if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body): + kernel["elementwise_op"] = m.group(1) + kernel["num_d_tensors"] = int(m.group(2)) + + # Algorithm parameters - support ANY_INT wildcard + if m := re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body + ): + kernel["tile_m"] = int(m.group(1)) + kernel["tile_n"] = int(m.group(2)) + kernel["tile_k"] = int(m.group(3)) + + # Wave: support ANY_INT, -1, and "*" as wildcards + if m := re.search( + r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", + add_body, + ): + kernel["warp_m"] = parse_int_or_wildcard(m.group(1)) + kernel["warp_n"] = parse_int_or_wildcard(m.group(2)) + kernel["warp_k"] = parse_int_or_wildcard(m.group(3)) + + # Warp: support ANY_INT, -1, and "*" as wildcards + if m := re.search( + r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", + add_body, + ): + kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1)) + kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2)) + kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3)) + + # Pipeline/Scheduler: support "*" wildcard + if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body): + kernel["pipeline"] = m.group(1) + if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body): + kernel["scheduler"] = m.group(1) + if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body): + kernel["epilogue"] = m.group(1) + if m := re.search( + r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)", + add_body, + re.I, + ): + kernel["pad_m"] = m.group(1).lower() == "true" + kernel["pad_n"] = m.group(2).lower() == "true" + kernel["pad_k"] = m.group(3).lower() == "true" + + # Shorthand format: .add("dtype", "layout", M, N, K) + if not kernel.get("dtype"): + if m := re.match( + r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)', + add_body, + ): + kernel["dtype"] = m.group(1) + kernel["layout"] = m.group(2) + kernel["tile_m"] = int(m.group(3)) + kernel["tile_n"] = int(m.group(4)) + kernel["tile_k"] = int(m.group(5)) + + if kernel.get("dtype"): + kernel["kernel_set"] = kernel_set_name + kernels.append(kernel) + + # Expand wildcards to multiple kernels + expanded = [] + for kernel in kernels: + expanded.extend(expand_gemm_wildcards(kernel)) + + # Apply autocorrect to each expanded kernel + return [auto_fill_gemm_defaults(k) for k in expanded] + + +def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]: + """Expand wildcard parameters to multiple valid configurations. + + When users specify ANY_INT (-1) or "*", this expands them to all + valid configurations for the target architecture. + + Note: Block size constraint filters invalid combos: + - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024 + - For 128x128 tile: only (32,32,k) works (16 warps * 64 = 1024) + - For 64x64 tile: both (16,16,k) and (32,32,k) work + """ + # Valid wave configurations for gfx942 + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + # Valid warp tile configurations for gfx942 fp16 + valid_warp_configs = [(16, 16, 32), (32, 32, 16)] + + # Valid pipelines and schedulers + valid_pipelines = ["compv3"] # compv4 requires special handling + valid_schedulers = ["intrawave"] + + # Check what needs expansion + needs_wave = kernel.get("warp_m") == -1 + needs_warp = kernel.get("warp_tile_m") == -1 + needs_pipeline = kernel.get("pipeline") == "*" + needs_scheduler = kernel.get("scheduler") == "*" + + if not any([needs_wave, needs_warp, needs_pipeline, needs_scheduler]): + return [kernel] + + # Determine configs to iterate + wave_configs = ( + valid_wave_configs + if needs_wave + else [ + (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1)) + ] + ) + warp_configs = ( + valid_warp_configs + if needs_warp + else [ + ( + kernel.get("warp_tile_m", 32), + kernel.get("warp_tile_n", 32), + kernel.get("warp_tile_k", 16), + ) + ] + ) + pipelines = ( + valid_pipelines if needs_pipeline else [kernel.get("pipeline", "compv3")] + ) + schedulers = ( + valid_schedulers if needs_scheduler else [kernel.get("scheduler", "intrawave")] + ) + + expanded = [] + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_configs: + # Check block size constraint: (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024 + tile_m = kernel.get("tile_m", 128) + tile_n = kernel.get("tile_n", 128) + num_warps = (tile_m // wtm) * (tile_n // wtn) + if num_warps * 64 > 1024: + continue # Skip invalid config + + for pipe in pipelines: + for sched in schedulers: + new_kernel = kernel.copy() + new_kernel["warp_m"] = wm + new_kernel["warp_n"] = wn + new_kernel["warp_k"] = wk + new_kernel["warp_tile_m"] = wtm + new_kernel["warp_tile_n"] = wtn + new_kernel["warp_tile_k"] = wtk + new_kernel["pipeline"] = pipe + new_kernel["scheduler"] = sched + expanded.append(new_kernel) + + if expanded: + print(f" [WILDCARD] Expanded 1 declaration -> {len(expanded)} kernel(s)") + + return expanded if expanded else [kernel] + + +def auto_fill_gemm_defaults(kernel: Dict) -> Dict: + """Auto-fill missing GEMM parameters with sensible defaults (autofill + autocorrect). + + This implements: + 1. AUTOFILL: Missing parameters are filled with valid defaults + 2. AUTOCORRECT: Invalid values are corrected to valid ones (e.g., wave(1,1,1) -> wave(2,2,1)) + """ + defaults = { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "layout": "rcr", + } + + # AUTOFILL: Fill missing parameters with defaults + autofilled = [] + for key, value in defaults.items(): + if key not in kernel or kernel[key] is None or kernel[key] == -1: + kernel[key] = value + autofilled.append(f"{key}={value}") + + if autofilled: + print(f" [AUTOFILL] {', '.join(autofilled)}") + + # AUTOCORRECT: Fix invalid wave configurations for gfx942 + # Valid wave configs: (1,4,1), (2,2,1), (4,1,1) + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + current_wave = ( + kernel.get("warp_m", 2), + kernel.get("warp_n", 2), + kernel.get("warp_k", 1), + ) + + if current_wave not in valid_wave_configs: + # Correct to (2,2,1) which is a balanced default + old = current_wave + kernel["warp_m"] = 2 + kernel["warp_n"] = 2 + kernel["warp_k"] = 1 + print(f" [AUTOCORRECT] wave{old} -> wave(2,2,1) (invalid for gfx942)") + + # AUTOCORRECT: Fix invalid pipeline/scheduler combinations + invalid_combos = [ + ("compv3", "interwave"), + ("compv4", "interwave"), + ] + current_combo = ( + kernel.get("pipeline", "compv3"), + kernel.get("scheduler", "intrawave"), + ) + if current_combo in invalid_combos: + old = current_combo + kernel["scheduler"] = "intrawave" + print( + f" [AUTOCORRECT] {old[0]}/{old[1]} -> {old[0]}/intrawave (invalid combo)" + ) + + # AUTOCORRECT: Fix warp tile to avoid exceeding max block size (1024 threads) + # Block size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64 + tile_m = kernel.get("tile_m", 128) + tile_n = kernel.get("tile_n", 128) + warp_tile_m = kernel.get("warp_tile_m", 32) + warp_tile_n = kernel.get("warp_tile_n", 32) + + num_warps = (tile_m // warp_tile_m) * (tile_n // warp_tile_n) + block_size = num_warps * 64 # 64 threads per warp + + if block_size > 1024: + # Find valid warp tile that fits + old_warp = (warp_tile_m, warp_tile_n, kernel.get("warp_tile_k", 16)) + + # For large tiles, use larger warp tiles + if tile_m >= 256: + kernel["warp_tile_m"] = 64 + if tile_n >= 256: + kernel["warp_tile_n"] = 64 + + # Recalculate + num_warps = (tile_m // kernel["warp_tile_m"]) * ( + tile_n // kernel["warp_tile_n"] + ) + block_size = num_warps * 64 + + if block_size <= 1024: + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})" + ) + else: + # Still too large, try even larger warp tiles + kernel["warp_tile_m"] = tile_m // 4 + kernel["warp_tile_n"] = tile_n // 4 + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)" + ) + + return kernel + + +def strip_cpp_strings_and_comments(content: str) -> str: + """Strip C++ string literals and comments that could cause false positives. + + Only strips: + - Comments (// and /* */) - always stripped + - Raw string literals (R"...") - always stripped (can contain anything) + - Regular strings ONLY if they contain problematic patterns like DECL_KERNEL_SET + + Preserves normal string literals like "fp16", "rcr" which are needed for parsing. + """ + result = [] + i = 0 + n = len(content) + + # Patterns that indicate a string is problematic and should be stripped + problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + + while i < n: + # Check for raw string literal: R"delimiter(...)delimiter" + # Always strip these as they can contain arbitrary content + if i < n - 1 and content[i] == "R" and content[i + 1] == '"': + # Find the delimiter (between R" and () + j = i + 2 + delimiter_start = j + while j < n and content[j] != "(": + j += 1 + delimiter = content[delimiter_start:j] + # Find the closing )delimiter" + end_marker = ")" + delimiter + '"' + end_pos = content.find(end_marker, j + 1) + if end_pos != -1: + # Replace with spaces to preserve line numbers + span = content[i : end_pos + len(end_marker)] + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + len(end_marker) + continue + + # Check for regular string literal - only strip if it contains problematic patterns + if content[i] == '"': + j = i + 1 + while j < n: + if content[j] == "\\" and j + 1 < n: + j += 2 # Skip escaped character + elif content[j] == '"': + j += 1 + break + else: + j += 1 + string_content = content[i:j] + + # Only strip if this string contains problematic patterns + should_strip = any(pat in string_content for pat in problematic_patterns) + if should_strip: + result.append(" " * len(string_content)) + else: + result.append(string_content) + i = j + continue + + # Check for single-line comment - always strip + if i < n - 1 and content[i : i + 2] == "//": + j = i + while j < n and content[j] != "\n": + j += 1 + result.append(" " * (j - i)) + i = j + continue + + # Check for multi-line comment - always strip + if i < n - 1 and content[i : i + 2] == "/*": + end_pos = content.find("*/", i + 2) + if end_pos != -1: + span = content[i : end_pos + 2] + # Preserve newlines in multi-line comments + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + 2 + continue + + result.append(content[i]) + i += 1 + + return "".join(result) + + +def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: + """Detect example type and parse kernel declarations. + + Properly strips string literals and comments before parsing to avoid + picking up declarations inside strings or commented-out code. + """ + content = source_path.read_text() + content = strip_cpp_strings_and_comments(content) + + if "DECL_CONV_KERNEL_SET" in content: + return "conv", parse_conv_declarations(content) + elif "DECL_KERNEL_SET" in content: + return "gemm", parse_gemm_declarations(content) + return "unknown", [] + + +def generate_gemm_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None +) -> str: + """Generate GEMM kernel registration code for the dispatcher registry. + + Uses GeneratedKernelInstance to wrap the generated kernels + and provide the KernelInstance interface for the Dispatcher. + + If kernels list is provided with kernel_set info, generates separate + registration functions per kernel set. + """ + if not kernel_headers: + return " // No kernels to register" + + # Build mapping from kernel config pattern to kernel set + kernel_to_set = {} + kernel_sets = set() + if kernels: + for k in kernels: + tile_m = k.get("tile_m", 128) + tile_n = k.get("tile_n", 128) + tile_k = k.get("tile_k", 64) + warp_m = k.get("warp_m", 2) + warp_n = k.get("warp_n", 2) + warp_k = k.get("warp_k", 1) + warp_tile_m = k.get("warp_tile_m", 32) + warp_tile_n = k.get("warp_tile_n", 32) + warp_tile_k = k.get("warp_tile_k", 16) + + # Pattern that appears in kernel filename + key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + kernel_set = k.get("kernel_set", "default") + kernel_to_set[key_pattern] = kernel_set + kernel_sets.add(kernel_set) + + def generate_registration_block(h: Path) -> str: + """Generate registration code for a single kernel.""" + kernel_name = h.stem + ns = f"ns_{kernel_name}" + + # Parse pipeline, scheduler, and layout from kernel name + # Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_... + parts = kernel_name.split("_") + pipeline = "CompV3" + scheduler = "Intrawave" + epilogue = "CShuffle" + datatype = "FP16" + layout_a = "RowMajor" + layout_b = "ColMajor" + layout_c = "RowMajor" + + # Parse datatype (e.g., fp16, bf16, fp32) + dtype_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "int8": "INT8", + } + + # Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc) + # r = RowMajor, c = ColMajor + layout_map = {"r": "RowMajor", "c": "ColMajor"} + + # Find pipeline, epilogue, scheduler in the name parts + pipeline_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + scheduler_map = { + "intrawave": "Intrawave", + "interwave": "Interwave", + "auto": "Auto", + } + epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"} + + for part in parts: + if part in pipeline_map: + pipeline = pipeline_map[part] + if part in scheduler_map: + scheduler = scheduler_map[part] + if part in epilogue_map: + epilogue = epilogue_map[part] + if part in dtype_map: + datatype = dtype_map[part] + # Parse 3-char layout codes (e.g., rcr, rrr) + if len(part) == 3 and all(c in "rc" for c in part): + layout_a = layout_map[part[0]] + layout_b = layout_map[part[1]] + layout_c = layout_map[part[2]] + + block = [] + block.append(f" // Register kernel: {kernel_name}") + block.append(" {") + block.append(f" using SelectedKernel = {ns}::SelectedKernel;") + block.append(" ck_tile::dispatcher::KernelKey key;") + block.append( + f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + " key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;" + ) + block.append( + f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};" + ) + block.append( + f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};" + ) + block.append( + f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};" + ) + block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;") + block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;") + block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;") + block.append( + " key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;" + ) + block.append( + " key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;" + ) + block.append( + " key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;" + ) + block.append( + " key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;" + ) + block.append( + " key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;" + ) + block.append( + " key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;" + ) + block.append( + " key.algorithm.block_size = SelectedKernel::BlockSize;" + ) + block.append( + f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};" + ) + block.append( + f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};" + ) + block.append( + f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};" + ) + block.append(" key.gfx_arch = arch;") + block.append( + f' auto instance = std::make_shared>(key, "{kernel_name}");' + ) + block.append(" registry.register_kernel(instance);") + block.append(" }") + return "\n".join(block) + + def find_kernel_set(header: Path) -> str: + """Find which kernel set a header belongs to.""" + name = header.stem + for pattern, kset in kernel_to_set.items(): + if pattern in name: + return kset + return "default" + + # Group kernels by set + kernels_by_set = {} + for h in kernel_headers: + kset = find_kernel_set(h) + if kset not in kernels_by_set: + kernels_by_set[kset] = [] + kernels_by_set[kset].append(h) + + # If only one set or no set info, use simple registration + if len(kernels_by_set) <= 1: + lines = [" (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + return "\n".join(lines) + + # Multiple sets - generate registration for all, plus store per-set info + lines = [" // Register ALL kernels from all sets", " (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + + # Store per-set mapping for separate function generation + global _kernels_by_set_cache + _kernels_by_set_cache = (kernels_by_set, generate_registration_block) + + return "\n".join(lines) + + +# Global cache for per-set kernel info +_kernels_by_set_cache = None + + +def generate_per_set_functions(source_stem: str) -> str: + """Generate separate registration functions for each kernel set. + + Generates: + 1. Per-set functions: register_(registry, arch) + 2. String-based dispatcher: register_kernel_set("set_name", registry, arch) + 3. get_kernel_set_names() to list available sets + """ + global _kernels_by_set_cache + if not _kernels_by_set_cache: + return "" + + kernels_by_set, gen_block = _kernels_by_set_cache + _kernels_by_set_cache = None # Clear cache + + lines = [] + set_names = [] + + # Generate per-set functions + for set_name, headers in kernels_by_set.items(): + safe_name = set_name.replace("-", "_") + set_names.append((set_name, safe_name)) + lines.append( + f"inline void register_{safe_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{" + ) + lines.append(" (void)arch;") + for h in headers: + lines.append(gen_block(h)) + lines.append("}") + lines.append("") + + # Generate string-based dispatcher (only if multiple sets) + if len(set_names) > 0: + lines.append("// Dynamic registration by kernel set name") + lines.append( + "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {" + ) + for set_name, safe_name in set_names: + lines.append( + f' if (set_name == "{set_name}") {{ register_{safe_name}(registry, arch); return true; }}' + ) + lines.append(" return false; // Unknown set name") + lines.append("}") + lines.append("") + + # Generate helper to list available set names + lines.append("// Get list of available kernel set names") + lines.append("inline std::vector get_kernel_set_names() {") + names_str = ", ".join(f'"{name}"' for name, _ in set_names) + lines.append(f" return {{{names_str}}};") + lines.append("}") + lines.append("") + + return "\n".join(lines) + + +def generate_conv_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] +) -> str: + """Generate Conv kernel registration code for the dispatcher registry.""" + if not kernel_headers: + return " // No kernels to register" + + lines = [] + lines.append( + " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" + ) + + # For conv, we provide direct access to kernel launchers + for i, h in enumerate(kernel_headers): + kernel_name = h.stem + lines.append(f" // Kernel {i + 1}: {kernel_name}") + + return "\n".join(lines) + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen.""" + if not kernels: + return False + + variant_map = { + "forward": "forward", + "bwd_data": "bwd_data", + "backward_data": "bwd_data", + "bwd_weight": "bwd_weight", + "backward_weight": "bwd_weight", + } + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + # Add optional parameters if specified + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def generate_gemm_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate GEMM kernels for ALL declarations using unified codegen.""" + import json + + if not kernels: + return False + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = "multi_d" if k.get("elementwise_op") else "standard" + + # Build tile config JSON for this specific kernel + tile_config = { + "tile_m": [k.get("tile_m", 128)], + "tile_n": [k.get("tile_n", 128)], + "tile_k": [k.get("tile_k", 32)], + "warp_m": [k.get("warp_m", 2)], + "warp_n": [k.get("warp_n", 2)], + "warp_k": [k.get("warp_k", 1)], + "warp_tile_m": [k.get("warp_tile_m", 32)], + "warp_tile_n": [k.get("warp_tile_n", 32)], + "warp_tile_k": [k.get("warp_tile_k", 16)], + } + + trait_config = { + "pipeline": [k.get("pipeline", "compv3")], + "epilogue": [k.get("epilogue", "cshuffle")], + "scheduler": [k.get("scheduler", "intrawave")], + "pad_m": [k.get("pad_m", False)], + "pad_n": [k.get("pad_n", False)], + "pad_k": [k.get("pad_k", False)], + "persistent": [False], + } + + config_json = json.dumps( + {"tile_config": tile_config, "trait_config": trait_config} + ) + + cmd = [ + sys.executable, + str(codegen_dir / "unified_gemm_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--layout", + k.get("layout", "rcr"), + "--variants", + variant, + "--output", + str(output_dir), + "--tile-config-json", + config_json, + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def compile_kernel(args: Tuple) -> Tuple[str, bool, str]: + """Compile a single kernel to object file.""" + kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args + kernel_name = kernel_hpp.stem + + wrapper_cpp = output_dir / f"{kernel_name}.cpp" + wrapper_cpp.write_text( + f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n' + ) + + obj_file = output_dir / f"{kernel_name}.o" + + cmd = [ + hipcc, + "-c", + "-fPIC", + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "--offload-compress", + ] + + for inc_dir in include_dirs: + cmd.extend(["-I", str(inc_dir)]) + cmd.extend(["-I", str(kernel_hpp.parent)]) + cmd.extend(["-o", str(obj_file), str(wrapper_cpp)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + return (kernel_name, False, result.stderr[:500]) + return (kernel_name, True, str(obj_file)) + + +def main(): + parser = argparse.ArgumentParser(description="Build example kernels") + parser.add_argument("source", type=Path, help="C++ source file") + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--include-dirs", type=str, required=True) + parser.add_argument("--gpu-target", type=str, default="gfx942") + parser.add_argument("--jobs", type=int, default=os.cpu_count()) + parser.add_argument( + "--target-name", type=str, help="CMake target name (for library naming)" + ) + args = parser.parse_args() + + script_dir = Path(__file__).parent + codegen_dir = script_dir.parent / "codegen" + source_stem = args.source.stem # e.g., "01_basic_gemm" + target_name = args.target_name or source_stem # e.g., "gemm_01_basic" from CMake + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Detect and parse + example_type, kernels = detect_and_parse(args.source) + + if example_type == "conv": + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + print( + f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)" + ) + elif example_type == "gemm": + k = kernels[0] if kernels else {} + print( + f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)" + ) + else: + print(f"[{target_name}] No kernel declarations - creating empty library") + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)], check=True) + header = args.output_dir / f"{source_stem}_kernels.hpp" + header.write_text(f"// No kernels for {target_name}\n#pragma once\n") + return 0 + + # Generate kernels + print(f"[{target_name}] Generating kernels...") + if example_type == "conv": + success = generate_conv_kernels(kernels, args.output_dir, codegen_dir) + else: + success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir) + + if not success: + print(f"[{target_name}] Kernel generation failed!") + return 1 + + # Find generated headers + if example_type == "gemm": + kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) + else: + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + prefix_map = { + "forward": "conv_fwd", + "bwd_data": "conv_bwdd", + "bwd_weight": "conv_bwdw", + } + prefix = prefix_map.get(variant, "conv_fwd") + kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + + if not kernel_headers: + print(f"[{target_name}] No kernel headers generated!") + return 1 + + print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...") + + include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")] + hipcc = find_hipcc() + + work = [ + ( + h, + args.output_dir, + include_dirs, + hipcc, + args.gpu_target, + i + 1, + len(kernel_headers), + ) + for i, h in enumerate(kernel_headers) + ] + + obj_files = [] + failed = [] + + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(compile_kernel, w): w[0].name for w in work} + for future in as_completed(futures): + name, ok, result = future.result() + if ok: + obj_files.append(result) + else: + failed.append((name, result)) + print(f"[{target_name}] FAILED: {name}") + + if failed: + print(f"[{target_name}] {len(failed)} kernels failed") + for name, err in failed[:3]: + print(f" {name}: {err[:200]}") + return 1 + + # Create static library (use target_name for CMake compatibility) + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True) + + # Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM) + header_path = args.output_dir / f"{source_stem}_kernels.hpp" + + # Build includes + includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers) + + # Build kernel registration entries + # Function name uses source_stem (e.g., register_01_basic_gemm_kernels) + func_name = f"register_{source_stem}_kernels" + + # Generate registration code based on example type + if example_type == "gemm": + register_body = generate_gemm_registration(kernel_headers, target_name, kernels) + else: + register_body = generate_conv_registration(kernel_headers, target_name, kernels) + + # Generate appropriate header based on example type + if example_type == "conv" and kernel_headers: + launcher_aliases = [] + + # Helper to find kernel by dtype and type + def find_kernel_by_dtype_type(headers, dtype, conv_type_marker): + """Find kernel matching dtype and conv type, prioritize fp16.""" + matching = [h for h in headers if conv_type_marker in h.stem] + # Prefer fp16 over bf16 for default launchers + fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem] + return ( + fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None) + ) + + # Check what conv types are in the declarations + has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels) + has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels) + has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels) + + # Export dtype-specific launcher aliases for each available dtype + for dtype in ["fp16", "bf16", "fp32"]: + dtype_fwd_kernels = [ + h + for h in kernel_headers + if "_fwd_" in h.stem and f"_{dtype}_" in h.stem + ] + if dtype_fwd_kernels: + k = dtype_fwd_kernels[0] + ns = f"ns_{k.stem}" + dtype_upper = dtype.upper() + launcher_aliases.append( + f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;" + ) + + # Export generic launcher aliases (prioritize fp16) + if has_fwd: + fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_") + if fwd_kernel: + fwd_ns = f"ns_{fwd_kernel.stem}" + launcher_aliases.append( + f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + launcher_aliases.append( + f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + + if has_bwd_data: + bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") + if bwdd_kernel: + bwdd_ns = f"ns_{bwdd_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + if not has_fwd: # If no fwd, use bwd_data as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + + if has_bwd_weight: + bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") + if bwdw_kernel: + bwdw_ns = f"ns_{bwdw_kernel.stem}" + launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + if ( + not has_fwd and not has_bwd_data + ): # If no fwd or bwdd, use bwdw as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + + launcher_section = "\n".join(launcher_aliases) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" + +namespace generated {{ + +// Kernel launchers for direct use +{launcher_section} + +// Registration function +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif +""" + else: + # GEMM: Generate per-set functions if multiple kernel sets declared + per_set_funcs = generate_per_set_functions(source_stem) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" + +namespace generated {{ + +// Register ALL kernels from all declared sets +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +{per_set_funcs} +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif + +// Register a specific kernel set by name (for multi-registry patterns) +// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch) +#ifndef REGISTER_KERNEL_SET +#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch) +#endif +""" + header_path.write_text(header_content) + + print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py new file mode 100755 index 0000000000..911ea61bd7 --- /dev/null +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build kernels in parallel - one translation unit per kernel. + +This script is called at make time (not cmake time) to avoid slow cmake configuration. +""" + +import argparse +import os +import subprocess +import sys +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def find_hipcc(): + """Find hipcc compiler.""" + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + shutil.which("hipcc") if shutil else None, + ] + for path in candidates: + if path and os.path.isfile(path): + return path + return "hipcc" # Assume in PATH + + +def compile_kernel(args): + """Compile a single kernel.""" + kernel_hpp, output_dir, include_dirs, hipcc = args + kernel_name = kernel_hpp.stem + + # Create wrapper .cpp + wrapper_cpp = output_dir / f"{kernel_name}.cpp" + wrapper_cpp.write_text(f'''// Auto-generated wrapper +#include "{kernel_hpp.name}" +namespace {{ volatile bool _k = true; }} +''') + + # Compile to object + obj_file = output_dir / f"{kernel_name}.o" + + cmd = [ + hipcc, + "-c", + "-fPIC", + "-std=c++17", + "-O3", + "--offload-arch=gfx942", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "--offload-compress", + ] + + for inc_dir in include_dirs: + cmd.extend(["-I", str(inc_dir)]) + cmd.extend(["-I", str(kernel_hpp.parent)]) + + cmd.extend(["-o", str(obj_file), str(wrapper_cpp)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + return (kernel_name, False, result.stderr) + return (kernel_name, True, str(obj_file)) + + +def main(): + parser = argparse.ArgumentParser(description="Build kernels in parallel") + parser.add_argument("--kernel-dir", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--include-dirs", type=str, required=True) + parser.add_argument("--jobs", type=int, default=os.cpu_count()) + args = parser.parse_args() + + # Find kernel headers + kernel_headers = list(args.kernel_dir.glob("gemm_*.hpp")) + list( + args.kernel_dir.glob("conv_*.hpp") + ) + + if not kernel_headers: + print("No kernels found to build") + return 0 + + print(f"Building {len(kernel_headers)} kernels with {args.jobs} parallel jobs...") + + include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")] + hipcc = find_hipcc() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare work items + work = [(h, args.output_dir, include_dirs, hipcc) for h in kernel_headers] + + # Compile in parallel + obj_files = [] + failed = [] + + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(compile_kernel, w): w[0].name for w in work} + + for i, future in enumerate(as_completed(futures), 1): + name, success, result = future.result() + if success: + obj_files.append(result) + print(f"[{i}/{len(kernel_headers)}] Built: {name}") + else: + failed.append((name, result)) + print(f"[{i}/{len(kernel_headers)}] FAILED: {name}") + + if failed: + print(f"\n{len(failed)} kernels failed to compile:") + for name, err in failed[:5]: + print(f" {name}: {err[:100]}") + return 1 + + # Link into shared library + print(f"\nLinking {len(obj_files)} objects into libdispatcher_kernels.so...") + lib_path = args.output_dir / "libdispatcher_kernels.so" + + link_cmd = [hipcc, "-shared", "-fPIC", "-o", str(lib_path)] + obj_files + result = subprocess.run(link_cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"Linking failed: {result.stderr}") + return 1 + + print(f"✓ Built: {lib_path}") + return 0 + + +if __name__ == "__main__": + import shutil + + sys.exit(main()) diff --git a/dispatcher/scripts/stress_test_autocorrect.py b/dispatcher/scripts/stress_test_autocorrect.py new file mode 100644 index 0000000000..13e92abffa --- /dev/null +++ b/dispatcher/scripts/stress_test_autocorrect.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Stress Test for Auto-Correction and Codegen + +This script tests the robustness of: +1. GEMM auto-correction (Python) +2. Conv auto-correction (Python) +3. C++ kernel declaration validation and wildcard expansion +4. Architecture filtering + +Usage: + python3 scripts/stress_test_autocorrect.py [--arch gfx942] [--samples 50] [--verbose] +""" + +import argparse +import random +import sys +from pathlib import Path + +# Add paths for imports +dispatcher_root = Path(__file__).parent.parent +sys.path.insert(0, str(dispatcher_root / "python")) +sys.path.insert(0, str(dispatcher_root / "codegen")) +sys.path.insert(0, str(dispatcher_root / "scripts")) + +from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 + +# Import validation/expansion functions from compile scripts +from compile_gemm_examples import ( # noqa: E402 + validate_kernel_config, + expand_declaration_with_arch_filter, +) +from compile_conv_examples import ( # noqa: E402 + validate_conv_kernel_config, + expand_conv_declaration_with_arch_filter, +) + + +# ============================================================================= +# TEST PARAMETERS +# ============================================================================= + +# Valid dtypes +DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8"] + +# Valid layouts +LAYOUTS = ["rcr", "rrr", "crr", "ccr"] + +# Tile sizes (some valid, some invalid) +TILE_SIZES = [ + (32, 32, 16), + (64, 64, 32), + (128, 128, 32), + (256, 256, 64), + (128, 256, 32), + (256, 128, 32), + # Invalid sizes to test auto-correction + (100, 100, 50), + (17, 17, 17), + (512, 512, 128), +] + +# Wave configs (some valid, some invalid) +WAVE_CONFIGS = [ + (1, 1, 1), + (1, 2, 1), + (2, 1, 1), + (2, 2, 1), + (1, 4, 1), + (4, 1, 1), + (2, 4, 1), + (4, 2, 1), + # Invalid configs to test auto-correction + (3, 3, 1), + (5, 5, 1), + (1, 1, 2), +] + +# Warp tile sizes (some valid, some invalid) +WARP_TILES = [ + (16, 16, 16), + (16, 16, 32), + (32, 32, 8), + (32, 32, 16), + # Invalid tiles to test auto-correction + (48, 48, 24), + (64, 64, 32), +] + +# Pipelines and schedulers +PIPELINES = ["compv3", "compv4", "flatmma", "invalid_pipeline"] +SCHEDULERS = ["intrawave", "interwave", "invalid_scheduler"] + +# Architectures +ARCHS = ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1200", "gfx1201"] + + +# ============================================================================= +# TEST FUNCTIONS +# ============================================================================= + + +def generate_random_gemm_config(): + """Generate a random GEMM configuration (may be invalid).""" + dtype = random.choice(DTYPES) + layout = random.choice(LAYOUTS) + tile = random.choice(TILE_SIZES) + wave = random.choice(WAVE_CONFIGS) + warp = random.choice(WARP_TILES) + pipeline = random.choice(PIPELINES) + scheduler = random.choice(SCHEDULERS) + arch = random.choice(ARCHS) + + return { + "name": f"test_{dtype}_{layout}_{tile[0]}x{tile[1]}x{tile[2]}", + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "dtype_acc": "fp32", + "layout": layout, + "tile_m": tile[0], + "tile_n": tile[1], + "tile_k": tile[2], + "wave_m": wave[0], + "wave_n": wave[1], + "wave_k": wave[2], + "warp_m": warp[0], + "warp_n": warp[1], + "warp_k": warp[2], + "pipeline": pipeline, + "scheduler": scheduler, + "arch": arch, + } + + +def generate_random_conv_config(): + """Generate a random Conv configuration (may be invalid).""" + dtype = random.choice(["fp16", "bf16"]) + tile_k = random.choice([64, 128, 256]) + tile_c = random.choice([64, 128, 256]) + wave = random.choice(WAVE_CONFIGS) + warp = random.choice(WARP_TILES) + pipeline = random.choice(["compv3", "compv4"]) + scheduler = random.choice(["intrawave"]) + arch = random.choice(ARCHS) + + return { + "name": f"test_conv_{dtype}_{tile_k}x{tile_c}", + "dtype": dtype, + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": wave[0], + "wave_n": wave[1], + "wave_k": wave[2], + "warp_m": warp[0], + "warp_n": warp[1], + "warp_k": warp[2], + "pipeline": pipeline, + "scheduler": scheduler, + "arch": arch, + } + + +def test_gemm_validation(config, verbose=False): + """Test GEMM validation and auto-correction.""" + arch = config.get("arch", "gfx942") + is_valid, error_msg = validate_kernel_config(config, arch) + + result = { + "config": config, + "is_valid": is_valid, + "error_msg": error_msg, + "expanded": [], + "auto_corrected": None, + } + + if not is_valid: + # Try wildcard expansion + wildcard_config = config.copy() + wildcard_config["wave_m"] = -1 + wildcard_config["wave_n"] = -1 + wildcard_config["warp_m"] = -1 + wildcard_config["warp_n"] = -1 + wildcard_config["pipeline"] = "*" + wildcard_config["scheduler"] = "*" + + expanded = expand_declaration_with_arch_filter(wildcard_config, arch) + result["expanded"] = expanded + + if verbose: + print(f"\n Config: {config['name']}") + print(f" Valid: {is_valid}") + if not is_valid: + print(f" Error: {error_msg[:80]}...") + print(f" Expanded to: {len(result['expanded'])} configurations") + + return result + + +def test_python_autocorrect(verbose=False): + """Test Python auto-correction for GEMM KernelConfig.""" + print("\n" + "=" * 70) + print(" PYTHON AUTO-CORRECTION TEST (GEMM KernelConfig)") + print("=" * 70) + + test_cases = [ + # Valid config + { + "name": "valid_fp16", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "gfx_arch": "gfx942", + }, + # Invalid wave config + { + "name": "invalid_wave", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 1, + "wave_n": 1, + "wave_k": 1, # Invalid for gfx942 + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "gfx_arch": "gfx942", + }, + # Invalid scheduler + { + "name": "invalid_scheduler", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "interwave", # May not be valid for all archs + "gfx_arch": "gfx942", + }, + ] + + results = {"passed": 0, "failed": 0, "details": []} + + for tc in test_cases: + try: + config = KernelConfig() + config.dtype_a = tc["dtype_a"] + config.dtype_b = tc["dtype_b"] + config.dtype_c = tc["dtype_c"] + config.dtype_acc = tc["dtype_acc"] + config.tile_m = tc["tile_m"] + config.tile_n = tc["tile_n"] + config.tile_k = tc["tile_k"] + config.wave_m = tc["wave_m"] + config.wave_n = tc["wave_n"] + config.wave_k = tc["wave_k"] + config.warp_m = tc["warp_m"] + config.warp_n = tc["warp_n"] + config.warp_k = tc["warp_k"] + config.pipeline = tc["pipeline"] + config.scheduler = tc["scheduler"] + config.gfx_arch = tc["gfx_arch"] + + corrected, was_modified, corrections = auto_correct_kernel_config( + config, verbose=verbose + ) + + results["passed"] += 1 + results["details"].append( + { + "name": tc["name"], + "status": "PASS", + "was_modified": was_modified, + "corrections": corrections, + } + ) + + if verbose: + print(f"\n {tc['name']}: PASS") + if was_modified: + print(f" Modified: {len(corrections)} correction(s)") + for c in corrections: + print(f" • {c}") + + except Exception as e: + results["failed"] += 1 + results["details"].append( + {"name": tc["name"], "status": "FAIL", "error": str(e)} + ) + if verbose: + print(f"\n {tc['name']}: FAIL - {e}") + + print(f"\n Summary: {results['passed']} passed, {results['failed']} failed") + return results + + +def run_stress_test(arch, num_samples, verbose): + """Run the full stress test.""" + print("\n" + "=" * 70) + print(" DISPATCHER AUTO-CORRECTION & CODEGEN STRESS TEST") + print("=" * 70) + print(f" Target Architecture: {arch}") + print(f" Number of Samples: {num_samples}") + print("=" * 70) + + # Test 1: GEMM Validation + print("\n" + "-" * 70) + print(" TEST 1: GEMM Validation & Wildcard Expansion") + print("-" * 70) + + gemm_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0} + + for i in range(num_samples): + config = generate_random_gemm_config() + config["arch"] = arch # Override with target arch + + result = test_gemm_validation(config, verbose) + + if result["is_valid"]: + gemm_results["valid"] += 1 + else: + gemm_results["invalid"] += 1 + if result["expanded"]: + gemm_results["expanded"] += 1 + else: + gemm_results["expansion_failed"] += 1 + + print("\n GEMM Results:") + print(f" Valid configs: {gemm_results['valid']}") + print(f" Invalid configs: {gemm_results['invalid']}") + print(f" Successfully expanded: {gemm_results['expanded']}") + print(f" Expansion failed: {gemm_results['expansion_failed']}") + + # Test 2: Conv Validation + print("\n" + "-" * 70) + print(" TEST 2: Conv Validation & Wildcard Expansion") + print("-" * 70) + + conv_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0} + + for i in range(num_samples): + config = generate_random_conv_config() + config["arch"] = arch # Override with target arch + + is_valid, error_msg = validate_conv_kernel_config(config, arch) + + if is_valid: + conv_results["valid"] += 1 + else: + conv_results["invalid"] += 1 + # Try wildcard expansion + wildcard_config = config.copy() + wildcard_config["wave_m"] = -1 + wildcard_config["wave_n"] = -1 + wildcard_config["warp_m"] = -1 + wildcard_config["warp_n"] = -1 + + expanded = expand_conv_declaration_with_arch_filter(wildcard_config, arch) + if expanded: + conv_results["expanded"] += 1 + else: + conv_results["expansion_failed"] += 1 + + print("\n Conv Results:") + print(f" Valid configs: {conv_results['valid']}") + print(f" Invalid configs: {conv_results['invalid']}") + print(f" Successfully expanded: {conv_results['expanded']}") + print(f" Expansion failed: {conv_results['expansion_failed']}") + + # Test 3: Python Auto-Correction + print("\n" + "-" * 70) + print(" TEST 3: Python Auto-Correction (KernelConfig)") + print("-" * 70) + + py_results = test_python_autocorrect(verbose) + + # Test 4: Architecture-specific tests + print("\n" + "-" * 70) + print(" TEST 4: Architecture-Specific Validation") + print("-" * 70) + + arch_test_configs = [ + # fp16 should work on all archs + {"dtype": "fp16", "expected_archs": ARCHS}, + # bf16 works on all archs that have bf16_bf16_fp32 in warp_tile_combos + { + "dtype": "bf16", + "expected_archs": [ + "gfx908", + "gfx90a", + "gfx942", + "gfx950", + "gfx1100", + "gfx1200", + "gfx1201", + ], + }, + # fp8 works on archs that have fp8_fp8_fp32 in warp_tile_combos + { + "dtype": "fp8", + "expected_archs": ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"], + }, + ] + + for test in arch_test_configs: + dtype = test["dtype"] + print(f"\n Testing {dtype}:") + + for test_arch in ARCHS: + config = { + "name": f"arch_test_{dtype}_{test_arch}", + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "dtype_acc": "fp32", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, # Wildcard + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": -1, + "pipeline": "*", + "scheduler": "*", + "arch": test_arch, + } + + expanded = expand_declaration_with_arch_filter(config, test_arch) + status = "✓" if expanded else "✗" + expected = test_arch in test["expected_archs"] + match = "OK" if (bool(expanded) == expected) else "MISMATCH" + + if verbose or match == "MISMATCH": + print(f" {test_arch}: {status} ({len(expanded)} configs) [{match}]") + + # Summary + print("\n" + "=" * 70) + print(" STRESS TEST SUMMARY") + print("=" * 70) + print( + f" GEMM: {gemm_results['valid'] + gemm_results['expanded']}/{num_samples} handled" + ) + print( + f" Conv: {conv_results['valid'] + conv_results['expanded']}/{num_samples} handled" + ) + print( + f" Python Auto-Correct: {py_results['passed']}/{py_results['passed'] + py_results['failed']} passed" + ) + + total_success = ( + gemm_results["valid"] + + gemm_results["expanded"] + + conv_results["valid"] + + conv_results["expanded"] + + py_results["passed"] + ) + total_tests = num_samples * 2 + py_results["passed"] + py_results["failed"] + + print(f"\n Overall: {total_success}/{total_tests} tests handled successfully") + print("=" * 70) + + return ( + gemm_results["expansion_failed"] == 0 and conv_results["expansion_failed"] == 0 + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Stress test auto-correction and codegen" + ) + parser.add_argument( + "--arch", + default="gfx942", + choices=ARCHS, + help="Target GPU architecture (default: gfx942)", + ) + parser.add_argument( + "--samples", + type=int, + default=50, + help="Number of random samples to test (default: 50)", + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Show detailed output" + ) + parser.add_argument( + "--seed", type=int, default=None, help="Random seed for reproducibility" + ) + + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + + success = run_stress_test(args.arch, args.samples, args.verbose) + + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp new file mode 100644 index 0000000000..fdb400921e --- /dev/null +++ b/dispatcher/src/dispatcher.cpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +Dispatcher::Dispatcher(Registry* registry) + : registry_(registry ? registry : &Registry::instance()), + heuristic_(nullptr), + strategy_(SelectionStrategy::FirstFit) +{ +} + +void Dispatcher::set_heuristic(HeuristicFunction heuristic) +{ + heuristic_ = heuristic; + if(heuristic_) + { + strategy_ = SelectionStrategy::Heuristic; + } +} + +void Dispatcher::set_strategy(SelectionStrategy strategy) { strategy_ = strategy; } + +KernelInstancePtr Dispatcher::select_kernel(const Problem& problem) const +{ + if(!problem.is_valid()) + { + return nullptr; + } + + switch(strategy_) + { + case SelectionStrategy::FirstFit: return select_first_fit(problem); + case SelectionStrategy::Heuristic: return select_heuristic(problem); + default: return nullptr; + } +} + +float Dispatcher::run( + const void* a_ptr, const void* b_ptr, void* c_ptr, const Problem& problem, void* stream) const +{ + return run_fused(a_ptr, b_ptr, c_ptr, nullptr, problem, stream); +} + +float Dispatcher::run_fused(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const +{ + auto kernel = select_kernel(problem); + if(!kernel) + { + std::ostringstream oss; + oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N + << " K=" << problem.K; + throw std::runtime_error(oss.str()); + } + + return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); +} + +float Dispatcher::run_explicit(const std::string& kernel_id, + const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const +{ + auto kernel = registry_->lookup(kernel_id); + if(!kernel) + { + throw std::runtime_error("Kernel not found: " + kernel_id); + } + + if(!kernel->supports(problem)) + { + std::ostringstream oss; + oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M + << " N=" << problem.N << " K=" << problem.K; + throw std::runtime_error(oss.str()); + } + + return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); +} + +bool Dispatcher::validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const +{ + auto kernel = select_kernel(problem); + if(!kernel) + { + return false; + } + + return kernel->validate(a_ptr, b_ptr, c_ptr, d_ptrs, problem, tolerance); +} + +KernelInstancePtr Dispatcher::select_first_fit(const Problem& problem) const +{ + auto all_kernels = registry_->get_all(); + + for(const auto& kernel : all_kernels) + { + if(kernel->supports(problem)) + { + return kernel; + } + } + + return nullptr; +} + +KernelInstancePtr Dispatcher::select_heuristic(const Problem& problem) const +{ + if(!heuristic_) + { + // Fall back to first-fit if no heuristic available + return select_first_fit(problem); + } + + // Get ranked list of kernel identifiers from heuristic + auto candidates = heuristic_(problem); + + // Try each candidate in order + for(const auto& kernel_id : candidates) + { + auto kernel = registry_->lookup(kernel_id); + if(kernel && kernel->supports(problem)) + { + return kernel; + } + } + + // If no heuristic candidate works, fall back to first-fit + return select_first_fit(problem); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp new file mode 100644 index 0000000000..0d83afd613 --- /dev/null +++ b/dispatcher/src/registry.cpp @@ -0,0 +1,288 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include + +namespace ck_tile { +namespace dispatcher { + +Registry::Registry() + : name_("default"), + auto_export_enabled_(false), + auto_export_include_statistics_(true), + auto_export_on_every_registration_(true) +{ +} + +Registry::~Registry() +{ + // Perform auto-export on destruction if enabled (regardless of export_on_every_registration + // setting) + if(auto_export_enabled_) + { + perform_auto_export(); + } +} + +Registry::Registry(Registry&& other) noexcept + : mutex_() // mutex is not movable, create new one + , + kernels_(std::move(other.kernels_)), + name_(std::move(other.name_)), + auto_export_enabled_(other.auto_export_enabled_), + auto_export_filename_(std::move(other.auto_export_filename_)), + auto_export_include_statistics_(other.auto_export_include_statistics_), + auto_export_on_every_registration_(other.auto_export_on_every_registration_) +{ + // Disable auto-export on the moved-from object to prevent double export + other.auto_export_enabled_ = false; +} + +Registry& Registry::operator=(Registry&& other) noexcept +{ + if(this != &other) + { + std::lock_guard lock(mutex_); + std::lock_guard other_lock(other.mutex_); + + kernels_ = std::move(other.kernels_); + name_ = std::move(other.name_); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + + // Disable auto-export on the moved-from object + other.auto_export_enabled_ = false; + } + return *this; +} + +bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) +{ + if(!instance) + { + return false; + } + + const std::string identifier = instance->get_key().encode_identifier(); + + bool registered = false; + { + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + // Kernel with this identifier already exists + // Only replace if new priority is higher + if(priority > it->second.priority) + { + it->second.instance = instance; + it->second.priority = priority; + registered = true; + } + } + else + { + // New kernel, insert it + kernels_[identifier] = RegistryEntry{instance, priority}; + registered = true; + } + } + + // Perform auto-export if enabled and configured to export on every registration + if(registered && auto_export_enabled_ && auto_export_on_every_registration_) + { + perform_auto_export(); + } + + return registered; +} + +KernelInstancePtr Registry::lookup(const std::string& identifier) const +{ + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + return it->second.instance; + } + + return nullptr; +} + +KernelInstancePtr Registry::lookup(const KernelKey& key) const +{ + return lookup(key.encode_identifier()); +} + +std::vector Registry::get_all() const +{ + std::lock_guard lock(mutex_); + + std::vector result; + result.reserve(kernels_.size()); + + for(const auto& pair : kernels_) + { + result.push_back(pair.second.instance); + } + + return result; +} + +std::vector +Registry::filter(std::function predicate) const +{ + std::lock_guard lock(mutex_); + + std::vector result; + + for(const auto& pair : kernels_) + { + if(predicate(*pair.second.instance)) + { + result.push_back(pair.second.instance); + } + } + + return result; +} + +std::size_t Registry::size() const +{ + std::lock_guard lock(mutex_); + return kernels_.size(); +} + +bool Registry::empty() const +{ + std::lock_guard lock(mutex_); + return kernels_.empty(); +} + +void Registry::clear() +{ + std::lock_guard lock(mutex_); + kernels_.clear(); +} + +const std::string& Registry::get_name() const +{ + std::lock_guard lock(mutex_); + return name_; +} + +void Registry::set_name(const std::string& name) +{ + std::lock_guard lock(mutex_); + name_ = name; +} + +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + +std::string Registry::export_json(bool include_statistics) const +{ + return export_registry_json(*this, include_statistics); +} + +bool Registry::export_json_to_file(const std::string& filename, bool include_statistics) const +{ + return export_registry_json_to_file(*this, filename, include_statistics); +} + +void Registry::enable_auto_export(const std::string& filename, + bool include_statistics, + bool export_on_every_registration) +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = true; + auto_export_filename_ = filename; + auto_export_include_statistics_ = include_statistics; + auto_export_on_every_registration_ = export_on_every_registration; +} + +void Registry::disable_auto_export() +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = false; +} + +bool Registry::is_auto_export_enabled() const +{ + std::lock_guard lock(mutex_); + return auto_export_enabled_; +} + +void Registry::perform_auto_export() +{ + // Don't hold the lock during file I/O + std::string filename; + bool include_stats; + + { + std::lock_guard lock(mutex_); + if(!auto_export_enabled_) + { + return; + } + filename = auto_export_filename_; + include_stats = auto_export_include_statistics_; + } + + // Export without holding the lock + export_json_to_file(filename, include_stats); +} + +std::size_t Registry::merge_from(const Registry& other, Priority priority) +{ + auto other_kernels = other.get_all(); + std::size_t merged_count = 0; + + for(const auto& kernel : other_kernels) + { + if(register_kernel(kernel, priority)) + { + merged_count++; + } + } + + return merged_count; +} + +std::size_t Registry::filter_by_arch(const std::string& gpu_arch) +{ + ArchFilter filter(gpu_arch); + std::vector to_remove; + + { + std::lock_guard lock(mutex_); + + for(const auto& pair : kernels_) + { + if(!filter.is_valid(pair.second.instance->get_key())) + { + to_remove.push_back(pair.first); + } + } + + for(const auto& key : to_remove) + { + kernels_.erase(key); + } + } + + return to_remove.size(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt new file mode 100644 index 0000000000..6c20c18c95 --- /dev/null +++ b/dispatcher/tests/CMakeLists.txt @@ -0,0 +1,343 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher Tests (C++ and Python) +# ============================================================================= + +cmake_minimum_required(VERSION 3.16) + +# Find Python +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# ============================================================================= +# Python Tests +# ============================================================================= + +# Auto-correction and validation stress test +add_test( + NAME dispatcher_test_autocorrect + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_autocorrect PROPERTIES + LABELS "dispatcher;python;validation" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Verbose version of the test +add_test( + NAME dispatcher_test_autocorrect_verbose + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_autocorrect_verbose PROPERTIES + LABELS "dispatcher;python;validation;verbose" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Individual Python Test Categories +add_test( + NAME dispatcher_test_gemm_validation + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestGemmValidation test_autocorrect.TestGemmExpansion -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_gemm_validation PROPERTIES + LABELS "dispatcher;python;gemm;validation" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_python_autocorrect + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestPythonAutoCorrect -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_python_autocorrect PROPERTIES + LABELS "dispatcher;python;autocorrect" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_stress + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestStressRandom -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_stress PROPERTIES + LABELS "dispatcher;python;stress" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_arch_support + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestArchitectureSupport -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_arch_support PROPERTIES + LABELS "dispatcher;python;arch" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Stress Test Script +add_test( + NAME dispatcher_stress_test + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/stress_test_autocorrect.py + --arch gfx942 --samples 30 --seed 42 + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_stress_test PROPERTIES + LABELS "dispatcher;python;stress;integration" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# Integration Tests (mimic examples) +# ============================================================================= + +# Full integration test suite +add_test( + NAME dispatcher_integration_tests + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_tests PROPERTIES + LABELS "dispatcher;python;integration;examples" + TIMEOUT 600 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Quick integration test (utilities only) +add_test( + NAME dispatcher_integration_quick + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestUtilityImports -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_quick PROPERTIES + LABELS "dispatcher;python;integration;quick" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# GEMM examples integration +add_test( + NAME dispatcher_integration_gemm + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestGemmPythonExamples -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_gemm PROPERTIES + LABELS "dispatcher;python;integration;gemm" + TIMEOUT 300 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# C++ Tests (Google Test) +# ============================================================================= + +# Include Google Test setup +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake") + include(${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake) +else() + include(gtest) +endif() + +# Mock kernel instance for testing (shared across tests) +add_library(dispatcher_test_utils STATIC + test_mock_kernel.cpp +) + +target_include_directories(dispatcher_test_utils PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) + +target_link_libraries(dispatcher_test_utils PRIVATE + ck_tile_dispatcher +) + +# Test executables using Google Test +set(TEST_SOURCES + # Core unit tests + test_kernel_key.cpp + test_problem.cpp + test_registry.cpp + test_dispatcher.cpp + test_tile_backend.cpp + + # Extended unit tests (more comprehensive coverage) + test_kernel_key_extended.cpp + test_problem_extended.cpp + test_registry_extended.cpp + test_dispatcher_extended.cpp + + # Regression tests (known issues and edge cases) + test_regression.cpp + + # JSON export tests + test_json_export.cpp +) + +foreach(test_source ${TEST_SOURCES}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + GTest::gtest_main + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;unit") +endforeach() + +# Standalone integration tests (with their own main()) +set(STANDALONE_TESTS + test_minimal.cpp +) + +foreach(test_source ${STANDALONE_TESTS}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;integration") +endforeach() + +# ============================================================================= +# Real Kernel Tests (requires generated kernels) +# ============================================================================= + +set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/../generated_kernels") +set(KERNEL_REGISTRATION_HEADER "${KERNEL_OUTPUT_DIR}/dispatcher_wrappers/register_all_kernels.hpp") +set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py") + +option(BUILD_DISPATCHER_REAL_KERNEL_TESTS "Build tests with real GPU kernels" ON) + +if(BUILD_DISPATCHER_REAL_KERNEL_TESTS AND EXISTS "${CODEGEN_SCRIPT}") + message(STATUS "Setting up real kernel test generation") + + add_custom_command( + OUTPUT ${KERNEL_REGISTRATION_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${KERNEL_OUTPUT_DIR} + COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${KERNEL_OUTPUT_DIR} + --datatype fp16 + --layout rcr + --gpu-target gfx942 + --preselected fp16_rcr_essential + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating CK Tile kernels for real kernel tests..." + VERBATIM + ) + + add_custom_target(generate_test_kernels DEPENDS ${KERNEL_REGISTRATION_HEADER}) + + set(SINGLE_KERNEL_HEADER "${KERNEL_OUTPUT_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + + set(REAL_KERNEL_TESTS + test_real_kernel_simple + test_real_kernel_multi_size + test_real_kernel_performance + test_real_kernel_correctness + test_sanity_ck_tile + ) + + if(EXISTS "${SINGLE_KERNEL_HEADER}") + foreach(test_name ${REAL_KERNEL_TESTS}) + add_executable(${test_name} ${test_name}.cpp) + + add_dependencies(${test_name} generate_test_kernels) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + ) + + target_include_directories(${test_name} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${KERNEL_OUTPUT_DIR} + ) + + target_compile_options(${test_name} PRIVATE + -include ${SINGLE_KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${test_name} PRIVATE hip::device hip::host) + endif() + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;gpu;kernel") + endforeach() + endif() +endif() + +# ============================================================================= +# Custom Targets +# ============================================================================= + +add_custom_target(run_dispatcher_tests + COMMAND ${CMAKE_CTEST_COMMAND} -L dispatcher --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running all dispatcher tests" +) + +add_custom_target(test_dispatcher_python + COMMAND ${CMAKE_CTEST_COMMAND} -L "dispatcher;python" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running Python dispatcher tests" +) + +add_custom_target(test_dispatcher_cpp + COMMAND ${CMAKE_CTEST_COMMAND} -L "dispatcher;cpp" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running C++ dispatcher tests" +) + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "Dispatcher tests configured:") +message(STATUS " Run all: ctest -L dispatcher") +message(STATUS " Run Python: ctest -L 'dispatcher;python' or make test_dispatcher_python") +message(STATUS " Run C++: ctest -L 'dispatcher;cpp' or make test_dispatcher_cpp") +message(STATUS " Run verbose: ctest -R dispatcher_test_autocorrect_verbose") diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py new file mode 100644 index 0000000000..0ec3ebda3c --- /dev/null +++ b/dispatcher/tests/test_autocorrect.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Comprehensive Test Suite for Auto-Correction and Validation + +Tests: +1. GEMM validation and wildcard expansion +2. Conv validation and wildcard expansion +3. Python KernelConfig auto-correction +4. Architecture-specific dtype support +5. Edge cases and error handling + +Can be run as: + python3 tests/test_autocorrect.py # Run all tests + python3 tests/test_autocorrect.py -v # Verbose output + python3 tests/test_autocorrect.py TestGemmValidation # Run specific test class + ctest -R test_autocorrect # Via ctest + +Exit codes: + 0 = All tests passed + 1 = Some tests failed +""" + +import sys +import unittest +import random +from pathlib import Path + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "scripts")) + +# Import modules under test +from compile_gemm_examples import ( # noqa: E402 + validate_kernel_config, + expand_declaration_with_arch_filter, + is_wildcard_declaration, +) +from compile_conv_examples import ( # noqa: E402 + validate_conv_kernel_config, + expand_conv_declaration_with_arch_filter, + is_conv_wildcard_declaration, +) +from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 + + +# ============================================================================= +# TEST DATA +# ============================================================================= + +VALID_ARCHS = ["gfx90a", "gfx942", "gfx950"] +VALID_DTYPES = ["fp16", "bf16"] +VALID_LAYOUTS = ["rcr", "rrr"] +VALID_PIPELINES = ["compv3", "compv4"] +VALID_SCHEDULERS = ["intrawave"] + +# Known valid wave configs for gfx942 +VALID_WAVE_CONFIGS_GFX942 = [[1, 4, 1], [2, 2, 1], [4, 1, 1]] + +# Known valid warp tiles for fp16 on gfx942 +VALID_WARP_TILES_FP16_GFX942 = [[16, 16, 16], [16, 16, 32], [32, 32, 8], [32, 32, 16]] + + +# ============================================================================= +# GEMM VALIDATION TESTS +# ============================================================================= + + +class TestGemmValidation(unittest.TestCase): + """Test GEMM kernel validation.""" + + def test_valid_config(self): + """Valid configuration should pass validation.""" + config = { + "name": "test_valid", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_wave_config(self): + """Invalid wave config should fail validation.""" + config = { + "name": "test_invalid_wave", + "dtype_a": "fp16", + "wave_m": 3, # Invalid + "wave_n": 3, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_invalid_scheduler(self): + """Invalid scheduler should fail validation.""" + config = { + "name": "test_invalid_scheduler", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "interwave", # Invalid with compv4+cshuffle + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("trait", error.lower()) + + def test_wildcard_skips_validation(self): + """Wildcard declarations should skip validation.""" + config = { + "name": "test_wildcard", + "dtype_a": "fp16", + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_wildcard_declaration(config)) + is_valid, _ = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid) + + def test_unsupported_arch(self): + """Unsupported architecture should fail validation.""" + config = { + "name": "test_bad_arch", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx_invalid") + self.assertFalse(is_valid) + self.assertIn("unsupported", error.lower()) + + +class TestGemmExpansion(unittest.TestCase): + """Test GEMM wildcard expansion.""" + + def test_wave_expansion(self): + """Wave wildcard should expand to valid configs.""" + config = { + "name": "test_wave_expand", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + # All expanded configs should be valid + for exp in expanded: + is_valid, error = validate_kernel_config(exp, "gfx942") + self.assertTrue(is_valid, f"Expanded config invalid: {error}") + + def test_full_wildcard_expansion(self): + """Full wildcard should expand to multiple valid configs.""" + config = { + "name": "test_full_wildcard", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater( + len(expanded), 1, "Full wildcard should expand to multiple configs" + ) + + def test_explicit_config_not_expanded(self): + """Explicit (non-wildcard) config should not expand.""" + config = { + "name": "test_explicit", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertEqual(len(expanded), 1, "Explicit config should not expand") + + +# ============================================================================= +# CONV VALIDATION TESTS +# ============================================================================= + + +class TestConvValidation(unittest.TestCase): + """Test Conv kernel validation.""" + + def test_valid_conv_config(self): + """Valid conv configuration should pass validation.""" + config = { + "name": "test_valid_conv", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_conv_wave(self): + """Invalid wave config should fail conv validation.""" + config = { + "name": "test_invalid_conv_wave", + "dtype": "fp16", + "wave_m": 5, # Invalid + "wave_n": 5, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_conv_wildcard_detection(self): + """Should correctly detect conv wildcards.""" + wildcard_config = { + "wave_m": -1, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_conv_wildcard_declaration(wildcard_config)) + + explicit_config = { + "wave_m": 2, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertFalse(is_conv_wildcard_declaration(explicit_config)) + + +class TestConvExpansion(unittest.TestCase): + """Test Conv wildcard expansion.""" + + def test_conv_wave_expansion(self): + """Conv wave wildcard should expand to valid configs.""" + config = { + "name": "test_conv_wave_expand", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_conv_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + +# ============================================================================= +# PYTHON AUTO-CORRECTION TESTS +# ============================================================================= + + +class TestPythonAutoCorrect(unittest.TestCase): + """Test Python KernelConfig auto-correction.""" + + def test_autocorrect_invalid_wave(self): + """Auto-correction should fix invalid wave config.""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 1 # May be invalid + config.wave_n = 1 # May be invalid + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + corrected, was_modified, corrections = auto_correct_kernel_config( + config, verbose=False + ) + + # Should either be valid or corrected + self.assertIsNotNone(corrected) + if was_modified: + self.assertGreater(len(corrections), 0) + + def test_autocorrect_returns_three_values(self): + """Auto-correction should return (config, was_modified, corrections).""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 2 + config.wave_n = 2 + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + result = auto_correct_kernel_config(config, verbose=False) + + self.assertEqual(len(result), 3, "Should return 3 values") + corrected, was_modified, corrections = result + self.assertIsInstance(was_modified, bool) + self.assertIsInstance(corrections, list) + + +# ============================================================================= +# STRESS TESTS +# ============================================================================= + + +class TestStressRandom(unittest.TestCase): + """Stress test with random configurations.""" + + def test_random_gemm_configs(self): + """Random GEMM configs should either validate or expand successfully.""" + random.seed(42) # Reproducible + + dtypes = ["fp16", "bf16"] + layouts = ["rcr", "rrr"] + tiles = [(64, 64, 32), (128, 128, 32), (256, 256, 64)] + waves = [(1, 1, 1), (2, 2, 1), (1, 4, 1), (3, 3, 1)] # Some invalid + warps = [(16, 16, 16), (32, 32, 16), (48, 48, 24)] # Some invalid + pipelines = ["compv3", "compv4", "invalid"] + schedulers = ["intrawave", "interwave"] + + success_count = 0 + total_count = 30 + + for _ in range(total_count): + config = { + "name": "random_test", + "dtype_a": random.choice(dtypes), + "dtype_b": random.choice(dtypes), + "dtype_c": random.choice(dtypes), + "layout": random.choice(layouts), + "tile_m": random.choice(tiles)[0], + "tile_n": random.choice(tiles)[1], + "tile_k": random.choice(tiles)[2], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": random.choice(pipelines), + "scheduler": random.choice(schedulers), + } + + is_valid, _ = validate_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + wildcard["pipeline"] = "*" + wildcard["scheduler"] = "*" + + expanded = expand_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + # At least 50% should be handleable + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} configs were handleable", + ) + + def test_random_conv_configs(self): + """Random Conv configs should either validate or expand successfully.""" + random.seed(42) + + dtypes = ["fp16", "bf16"] + tiles = [(64, 64), (128, 128), (256, 256)] + waves = [(2, 2, 1), (1, 4, 1), (3, 3, 1)] + warps = [(16, 16, 16), (32, 32, 16)] + + success_count = 0 + total_count = 20 + + for _ in range(total_count): + config = { + "name": "random_conv_test", + "dtype": random.choice(dtypes), + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": random.choice(tiles)[0], + "tile_c": random.choice(tiles)[1], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": "compv4", + "scheduler": "intrawave", + } + + is_valid, _ = validate_conv_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + + expanded = expand_conv_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} conv configs were handleable", + ) + + +# ============================================================================= +# ARCHITECTURE TESTS +# ============================================================================= + + +class TestArchitectureSupport(unittest.TestCase): + """Test architecture-specific support.""" + + def test_gfx942_fp16_support(self): + """gfx942 should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support fp16") + + def test_gfx942_bf16_support(self): + """gfx942 should support bf16.""" + config = { + "dtype_a": "bf16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support bf16") + + def test_gfx90a_support(self): + """gfx90a should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx90a") + self.assertGreater(len(expanded), 0, "gfx90a should support fp16") + + +# ============================================================================= +# MAIN +# ============================================================================= + + +def main(): + """Run tests.""" + # Parse args for verbosity + verbosity = 2 if "-v" in sys.argv or "--verbose" in sys.argv else 1 + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add all test classes + suite.addTests(loader.loadTestsFromTestCase(TestGemmValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestGemmExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestConvValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestConvExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestPythonAutoCorrect)) + suite.addTests(loader.loadTestsFromTestCase(TestStressRandom)) + suite.addTests(loader.loadTestsFromTestCase(TestArchitectureSupport)) + + # Run tests + runner = unittest.TextTestRunner(verbosity=verbosity) + result = runner.run(suite) + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/tests/test_dispatcher.cpp b/dispatcher/tests/test_dispatcher.cpp new file mode 100644 index 0000000000..1e3893756c --- /dev/null +++ b/dispatcher/tests/test_dispatcher.cpp @@ -0,0 +1,296 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Dispatcher using Google Test + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +class DispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + // Clear registry before each test + Registry::instance().clear(); + } + + void TearDown() override + { + // Clean up after each test + Registry::instance().clear(); + } +}; + +TEST_F(DispatcherTest, SelectKernelFirstFit) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Select kernel for valid problem + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // Should select a kernel that supports the problem + // (order is not guaranteed, so just verify one is selected) + EXPECT_TRUE(selected->get_name() == "kernel1" || selected->get_name() == "kernel2"); + EXPECT_TRUE(selected->supports(problem)); +} + +TEST_F(DispatcherTest, SelectKernelInvalidProblem) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Invalid problem + Problem invalid_problem(0, 0, 0); + auto selected = dispatcher.select_kernel(invalid_problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelNoMatch) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + // Problem with dimensions not divisible by tile size + Problem problem(100, 100, 100); // Not divisible by 256 + auto selected = dispatcher.select_kernel(problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelHeuristic) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Set heuristic that prefers kernel2 + dispatcher.set_heuristic([](const Problem&) { + std::vector candidates; + auto key2 = make_test_key(128); + candidates.push_back(key2.encode_identifier()); + auto key1 = make_test_key(256); + candidates.push_back(key1.encode_identifier()); + return candidates; + }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel2"); +} + +TEST_F(DispatcherTest, SelectKernelHeuristicFallback) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Set heuristic that returns non-existent kernel + dispatcher.set_heuristic( + [](const Problem&) { return std::vector{"nonexistent_kernel"}; }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to first-fit + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} + +TEST_F(DispatcherTest, RunBasic) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Mock pointers (not actually used) + float a[1], b[1], c[1]; + + float time_ms = dispatcher.run(a, b, c, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run(a, b, c, problem), std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicit) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + float time_ms = dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunExplicitNotFound) +{ + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit("nonexistent", a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicitNotSupported) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + Problem problem(100, 100, 100); // Not divisible by 256 + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, Validate) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + EXPECT_TRUE(valid); +} + +TEST_F(DispatcherTest, ValidateNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + EXPECT_FALSE(valid); +} + +TEST_F(DispatcherTest, StrategySelection) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Test FirstFit strategy + dispatcher.set_strategy(Dispatcher::SelectionStrategy::FirstFit); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Test Heuristic strategy (without heuristic function - should fallback) + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +TEST_F(DispatcherTest, CustomRegistry) +{ + // Create custom registry instance (not singleton) + // Note: This requires Registry to allow non-singleton instances + // For now, we'll test with a separate registry instance + // In practice, custom registry would be created differently + + // Since Registry is singleton-only, we'll test that dispatcher + // can work with the singleton registry + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + registry.register_kernel(kernel); + + // Dispatcher defaults to singleton registry + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} diff --git a/dispatcher/tests/test_dispatcher_extended.cpp b/dispatcher/tests/test_dispatcher_extended.cpp new file mode 100644 index 0000000000..e8d7e4b5d1 --- /dev/null +++ b/dispatcher/tests/test_dispatcher_extended.cpp @@ -0,0 +1,499 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Dispatcher - covers selection strategies, heuristics, edge cases + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Basic Dispatcher Tests +// ============================================================================= + +class DispatcherBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherBasicTest, DefaultConstruction) +{ + Dispatcher dispatcher; + // Should not crash + SUCCEED(); +} + +TEST_F(DispatcherBasicTest, SelectKernelEmpty) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto kernel = dispatcher.select_kernel(problem); + EXPECT_EQ(kernel, nullptr); +} + +TEST_F(DispatcherBasicTest, SelectKernelSingle) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "test_kernel"); +} + +TEST_F(DispatcherBasicTest, SelectKernelMultiple) +{ + // Register multiple kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + // Should select one of the registered kernels + EXPECT_TRUE(selected->get_name() == "kernel_128" || selected->get_name() == "kernel_256" || + selected->get_name() == "kernel_512"); +} + +// ============================================================================= +// Selection Strategy Tests +// ============================================================================= + +class SelectionStrategyTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with different tile sizes + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(SelectionStrategyTest, FirstFitStrategy) +{ + Dispatcher dispatcher; + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // FirstFit returns first matching kernel +} + +TEST_F(SelectionStrategyTest, HeuristicStrategy) +{ + Dispatcher dispatcher; + + // Set heuristic that prefers larger tiles for large problems + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + if(p.M >= 1024 && p.N >= 1024) + { + // For large problems, prefer 512 tile + auto key = make_test_key(512); + return {key.encode_identifier()}; + } + // For small problems, prefer 128 tile + auto key = make_test_key(128); + return {key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem should get 512 tile + Problem large_problem(2048, 2048, 2048); + auto selected = dispatcher.select_kernel(large_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // Small problem should get 128 tile + Problem small_problem(256, 256, 256); + selected = dispatcher.select_kernel(small_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_128"); +} + +TEST_F(SelectionStrategyTest, HeuristicWithFallback) +{ + Dispatcher dispatcher; + + // Heuristic returns non-existent kernel first, then valid one + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {"nonexistent_kernel", key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); +} + +TEST_F(SelectionStrategyTest, SwitchBetweenStrategies) +{ + Dispatcher dispatcher; + + // Start with FirstFit + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Switch to Heuristic + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {key.encode_identifier()}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +// ============================================================================= +// Heuristic Function Tests +// ============================================================================= + +class HeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int tile : {64, 128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(HeuristicTest, SizeBasedHeuristic) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + std::vector candidates; + + // Problem-size based selection + int size = p.M * p.N * p.K; + + if(size >= 1024 * 1024 * 1024) + { + candidates.push_back(make_test_key(512).encode_identifier()); + candidates.push_back(make_test_key(256).encode_identifier()); + } + else if(size >= 256 * 256 * 256) + { + candidates.push_back(make_test_key(256).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + else + { + candidates.push_back(make_test_key(64).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + + return candidates; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem + auto selected = dispatcher.select_kernel(Problem(1024, 1024, 1024)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // Medium problem + selected = dispatcher.select_kernel(Problem(256, 256, 256)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); + + // Small problem + selected = dispatcher.select_kernel(Problem(64, 64, 64)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_64"); +} + +TEST_F(HeuristicTest, EmptyHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty list + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +TEST_F(HeuristicTest, InvalidHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid_kernel_1", "invalid_kernel_2"}; // All invalid + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +// ============================================================================= +// Dispatcher with Custom Registry Tests +// ============================================================================= + +class DispatcherCustomRegistryTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherCustomRegistryTest, UseCustomRegistry) +{ + Registry custom_registry; + custom_registry.set_name("custom"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "custom_kernel"); + custom_registry.register_kernel(kernel); + + Dispatcher dispatcher(&custom_registry); + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "custom_kernel"); +} + +TEST_F(DispatcherCustomRegistryTest, CustomRegistryIsolation) +{ + Registry custom_registry; + + auto key_custom = make_test_key(256); + auto key_global = make_test_key(512); + + custom_registry.register_kernel( + std::make_shared(key_custom, "custom_kernel")); + Registry::instance().register_kernel( + std::make_shared(key_global, "global_kernel")); + + Dispatcher custom_dispatcher(&custom_registry); + Dispatcher global_dispatcher; + + Problem problem(1024, 1024, 1024); + + auto custom_selected = custom_dispatcher.select_kernel(problem); + auto global_selected = global_dispatcher.select_kernel(problem); + + ASSERT_NE(custom_selected, nullptr); + ASSERT_NE(global_selected, nullptr); + + EXPECT_EQ(custom_selected->get_name(), "custom_kernel"); + EXPECT_EQ(global_selected->get_name(), "global_kernel"); +} + +// ============================================================================= +// Edge Cases Tests +// ============================================================================= + +class DispatcherEdgeCasesTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherEdgeCasesTest, InvalidProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Zero dimensions + Problem invalid(0, 1024, 1024); + EXPECT_FALSE(invalid.is_valid()); + + // The dispatcher should still attempt selection + // (validation is up to the kernel's supports() method) +} + +TEST_F(DispatcherEdgeCasesTest, KernelDoesNotSupportProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "selective_kernel", false); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Problem not divisible by tile size - kernel doesn't support it + Problem problem(1000, 1000, 1000); // Not divisible by 256 + + auto selected = dispatcher.select_kernel(problem); + // Should return nullptr since kernel doesn't support this problem + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherEdgeCasesTest, MultipleSelectionsConsistent) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Multiple selections should return the same kernel + auto selected1 = dispatcher.select_kernel(problem); + auto selected2 = dispatcher.select_kernel(problem); + auto selected3 = dispatcher.select_kernel(problem); + + ASSERT_NE(selected1, nullptr); + EXPECT_EQ(selected1, selected2); + EXPECT_EQ(selected2, selected3); +} + +// ============================================================================= +// Validate Method Tests +// ============================================================================= + +class DispatcherValidateTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherValidateTest, ValidateWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // MockKernelInstance always validates successfully + bool valid = dispatcher.validate(nullptr, nullptr, nullptr, nullptr, problem); + + // This depends on implementation - mock returns true + // Real validation would need actual data +} + +// ============================================================================= +// Run Method Tests (with mock) +// ============================================================================= + +class DispatcherRunTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherRunTest, RunWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock run (with null pointers - mock doesn't use them) + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock kernel returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); + + // Verify execution count + EXPECT_EQ(kernel_->get_execution_count(), 1); +} + +TEST_F(DispatcherRunTest, MultipleRuns) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + for(int i = 0; i < 10; i++) + { + (void)dispatcher.run(nullptr, nullptr, nullptr, problem); + } + + EXPECT_EQ(kernel_->get_execution_count(), 10); +} + +TEST_F(DispatcherRunTest, RunWithNoKernelThrows) +{ + Registry::instance().clear(); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Should throw when no kernel found + EXPECT_THROW((void)dispatcher.run(nullptr, nullptr, nullptr, problem), std::runtime_error); +} diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py new file mode 100644 index 0000000000..cfd18a3305 --- /dev/null +++ b/dispatcher/tests/test_examples_integration.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Integration tests that verify examples work correctly. + +These tests mimic the examples to ensure they continue working. +Run with: pytest test_examples_integration.py -v +""" + +import unittest +import subprocess +import sys +import os +from pathlib import Path + +# Get paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_ROOT = SCRIPT_DIR.parent +EXAMPLES_DIR = DISPATCHER_ROOT / "examples" +BUILD_DIR = DISPATCHER_ROOT / "build" +PYTHON_DIR = DISPATCHER_ROOT / "python" + +# Add python utilities to path +sys.path.insert(0, str(PYTHON_DIR)) + + +def run_python_example( + example_path: Path, timeout: int = 120 +) -> subprocess.CompletedProcess: + """Run a Python example and capture output.""" + env = os.environ.copy() + env["PYTHONPATH"] = str(PYTHON_DIR) + + return subprocess.run( + [sys.executable, str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + cwd=example_path.parent, + env=env, + ) + + +def run_cpp_example( + example_name: str, timeout: int = 60 +) -> subprocess.CompletedProcess: + """Run a C++ example and capture output.""" + example_path = BUILD_DIR / "examples" / example_name + + if not example_path.exists(): + return None + + return subprocess.run( + [str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + ) + + +class TestGemmPythonExamples(unittest.TestCase): + """Test GEMM Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + cls.gemm_examples_dir = EXAMPLES_DIR / "gemm" / "python" + if not cls.gemm_examples_dir.exists(): + raise unittest.SkipTest("GEMM Python examples not found") + + def test_01_basic_gemm(self): + """Test basic GEMM example.""" + example = self.gemm_examples_dir / "01_basic_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_batch_gemm(self): + """Test batch GEMM example.""" + example = self.gemm_examples_dir / "02_batch_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_benchmark(self): + """Test benchmark example.""" + example = self.gemm_examples_dir / "03_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_04_validation(self): + """Test validation example.""" + example = self.gemm_examples_dir / "04_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + # Should pass validation + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvPythonExamples(unittest.TestCase): + """Test Conv Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + if not cls.conv_examples_dir.exists(): + raise unittest.SkipTest("Conv Python examples not found") + + def test_01_basic_conv(self): + """Test basic conv example.""" + example = self.conv_examples_dir / "01_basic_conv.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_conv2d_fwd(self): + """Test 2D forward conv example.""" + example = self.conv_examples_dir / "02_conv2d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_conv3d_fwd(self): + """Test 3D forward conv example.""" + example = self.conv_examples_dir / "03_conv3d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_07_validation(self): + """Test validation example.""" + example = self.conv_examples_dir / "07_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestGemmCppExamples(unittest.TestCase): + """Test GEMM C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_gemm_01_basic(self): + """Test basic GEMM C++ example.""" + result = run_cpp_example("gemm_01_basic") + if result is None: + self.skipTest("gemm_01_basic not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_gemm_02_multi_size(self): + """Test multi-size GEMM C++ example.""" + result = run_cpp_example("gemm_02_multi_size") + if result is None: + self.skipTest("gemm_02_multi_size not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_gemm_04_validation(self): + """Test validation GEMM C++ example.""" + result = run_cpp_example("gemm_04_validation") + if result is None: + self.skipTest("gemm_04_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvCppExamples(unittest.TestCase): + """Test Conv C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_conv_01_forward(self): + """Test forward conv C++ example.""" + result = run_cpp_example("conv_01_forward") + if result is None: + self.skipTest("conv_01_forward not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_conv_02_validation(self): + """Test validation conv C++ example.""" + result = run_cpp_example("conv_02_validation") + if result is None: + self.skipTest("conv_02_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestUtilityImports(unittest.TestCase): + """Test that utility modules can be imported.""" + + def test_import_ctypes_utils(self): + """Test importing ctypes_utils.""" + try: + from ctypes_utils import KernelConfig, setup_gemm_dispatcher # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import ctypes_utils: {e}") + + def test_import_conv_utils(self): + """Test importing conv_utils.""" + try: + from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import conv_utils: {e}") + + def test_kernel_config_creation(self): + """Test creating a KernelConfig.""" + from ctypes_utils import KernelConfig + + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + ) + + self.assertEqual(config.dtype_a, "fp16") + self.assertEqual(config.layout_a, "row") + + def test_conv_signature_creation(self): + """Test creating a ConvSignature.""" + from conv_utils import ConvSignature + + sig = ConvSignature( + dtype_in="fp16", + dtype_wei="fp16", + dtype_out="fp16", + dtype_acc="fp32", + layout="nhwgc", + direction="forward", + num_dims=2, + ) + + self.assertEqual(sig.dtype_in, "fp16") + self.assertEqual(sig.direction, "forward") + + +class TestAutoCorrection(unittest.TestCase): + """Test auto-correction functionality.""" + + def test_gemm_auto_correct(self): + """Test GEMM auto-correction.""" + from ctypes_utils import KernelConfig, auto_correct_kernel_config + + # Create a config with invalid wave config + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + ) + + corrected, was_modified, corrections = auto_correct_kernel_config(config) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + def test_conv_auto_correct(self): + """Test Conv auto-correction.""" + from conv_utils import auto_correct_conv_config + + # Call with invalid wave config parameters + corrected, was_modified, corrections = auto_correct_conv_config( + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + dtype="fp16", + arch="gfx942", + ) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_json_export.cpp b/dispatcher/tests/test_json_export.cpp new file mode 100644 index 0000000000..4392729554 --- /dev/null +++ b/dispatcher/tests/test_json_export.cpp @@ -0,0 +1,448 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for JSON export functionality + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Export Tests +// ============================================================================= + +class JSONExportBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportBasicTest, ExportEmptyRegistry) +{ + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"kernels\""), std::string::npos); + // Empty registry should still produce valid JSON with kernels section +} + +TEST_F(JSONExportBasicTest, ExportSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"test_kernel\""), std::string::npos); +} + +TEST_F(JSONExportBasicTest, ExportMultipleKernels) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(false); + + // Should contain all kernel names + for(int i = 0; i < 5; i++) + { + EXPECT_NE(json.find("\"kernel_" + std::to_string(i) + "\""), std::string::npos); + } +} + +// ============================================================================= +// Export with Statistics Tests +// ============================================================================= + +class JSONExportStatisticsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportStatisticsTest, ExportWithStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); // Include statistics + + EXPECT_NE(json.find("\"statistics\""), std::string::npos); + EXPECT_NE(json.find("\"by_datatype\""), std::string::npos); + EXPECT_NE(json.find("\"by_pipeline\""), std::string::npos); +} + +TEST_F(JSONExportStatisticsTest, ExportWithoutStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); // No statistics + + // Statistics section might be minimal or absent + EXPECT_NE(json.find("\"kernels\""), std::string::npos); +} + +// ============================================================================= +// Metadata Tests +// ============================================================================= + +class JSONExportMetadataTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportMetadataTest, MetadataPresent) +{ + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"metadata\""), std::string::npos); + EXPECT_NE(json.find("\"timestamp\""), std::string::npos); + EXPECT_NE(json.find("\"total_kernels\""), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, CorrectKernelCount) +{ + const int num_kernels = 7; + for(int i = 0; i < num_kernels; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"total_kernels\": " + std::to_string(num_kernels)), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, RegistryNameIncluded) +{ + Registry::instance().set_name("test_registry"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"registry_name\""), std::string::npos); + EXPECT_NE(json.find("\"test_registry\""), std::string::npos); +} + +// ============================================================================= +// Export to File Tests +// ============================================================================= + +class JSONExportToFileTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + test_file_ = "/tmp/test_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONExportToFileTest, ExportToFile) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + bool success = Registry::instance().export_json_to_file(test_file_, true); + EXPECT_TRUE(success); + + // Verify file exists + std::ifstream file(test_file_); + EXPECT_TRUE(file.good()); + + // Verify content + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + EXPECT_NE(content.find("\"kernel\""), std::string::npos); +} + +TEST_F(JSONExportToFileTest, ExportToInvalidPath) +{ + bool success = Registry::instance().export_json_to_file("/invalid/path/file.json", true); + EXPECT_FALSE(success); +} + +// ============================================================================= +// Auto-Export Tests +// ============================================================================= + +class JSONAutoExportTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + Registry::instance().disable_auto_export(); + test_file_ = "/tmp/test_auto_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().disable_auto_export(); + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONAutoExportTest, EnableAutoExport) +{ + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().enable_auto_export(test_file_, true, false); + + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, DisableAutoExport) +{ + Registry::instance().enable_auto_export(test_file_, true, false); + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().disable_auto_export(); + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, AutoExportOnRegistration) +{ + // Enable auto-export with export_on_every_registration=true + Registry::instance().enable_auto_export(test_file_, true, false); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "auto_kernel"); + Registry::instance().register_kernel(kernel); + + // File might be created on registration or on exit depending on implementation + // Just verify auto-export is enabled + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +// ============================================================================= +// JSON Validity Tests +// ============================================================================= + +class JSONValidityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } + + // Simple JSON syntax checker + bool isValidJSON(const std::string& json) + { + int braces = 0; + int brackets = 0; + bool in_string = false; + char prev = '\0'; + + for(char c : json) + { + if(c == '"' && prev != '\\') + { + in_string = !in_string; + } + + if(!in_string) + { + if(c == '{') + braces++; + else if(c == '}') + braces--; + else if(c == '[') + brackets++; + else if(c == ']') + brackets--; + } + + if(braces < 0 || brackets < 0) + return false; + prev = c; + } + + return braces == 0 && brackets == 0 && !in_string; + } +}; + +TEST_F(JSONValidityTest, EmptyRegistryProducesValidJSON) +{ + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, SingleKernelProducesValidJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, ManyKernelsProduceValidJSON) +{ + for(int i = 0; i < 50; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, NoNullBytesInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + // Check for null bytes + EXPECT_EQ(json.find('\0'), std::string::npos); +} + +TEST_F(JSONValidityTest, NoPrintableGarbageInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + // All characters should be printable or whitespace + for(char c : json) + { + EXPECT_TRUE(std::isprint(c) || std::isspace(c)) + << "Non-printable character: " << static_cast(c); + } +} + +// ============================================================================= +// Kernel Details Tests +// ============================================================================= + +class JSONKernelDetailsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONKernelDetailsTest, SignatureIncluded) +{ + auto key = make_test_key(256); + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"signature\""), std::string::npos); + EXPECT_NE(json.find("\"dtype_a\""), std::string::npos); + EXPECT_NE(json.find("\"fp16\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, AlgorithmIncluded) +{ + auto key = make_test_key(256, 256, 32); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"algorithm\""), std::string::npos); + EXPECT_NE(json.find("\"tile_shape\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, IdentifierIncluded) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "my_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"identifier\""), std::string::npos); + EXPECT_NE(json.find("\"name\""), std::string::npos); + EXPECT_NE(json.find("\"my_kernel\""), std::string::npos); +} + +// ============================================================================= +// Multiple Registries Export Tests +// ============================================================================= + +class JSONMultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONMultipleRegistriesTest, DifferentRegistriesDifferentJSON) +{ + Registry reg1; + reg1.set_name("registry1"); + + Registry reg2; + reg2.set_name("registry2"); + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg2.register_kernel(std::make_shared(key2, "k2")); + + std::string json1 = reg1.export_json(true); + std::string json2 = reg2.export_json(true); + + EXPECT_NE(json1, json2); + + EXPECT_NE(json1.find("\"registry1\""), std::string::npos); + EXPECT_NE(json2.find("\"registry2\""), std::string::npos); + + EXPECT_NE(json1.find("\"k1\""), std::string::npos); + EXPECT_NE(json2.find("\"k2\""), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key.cpp b/dispatcher/tests/test_kernel_key.cpp new file mode 100644 index 0000000000..b35641952a --- /dev/null +++ b/dispatcher/tests/test_kernel_key.cpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for KernelKey using Google Test + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(KernelKeyTest, Construction) +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + + key.gfx_arch = "gfx942"; + + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.gfx_arch, "gfx942"); +} + +TEST(KernelKeyTest, Equality) +{ + // Use helper function to ensure all fields are initialized + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); + + // Change one value + KernelKey key3 = make_test_key(128, 256, 32, "gfx942"); + EXPECT_NE(key1, key3); + EXPECT_FALSE(key1 == key3); +} + +TEST(KernelKeyTest, EncodeIdentifier) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = true; + key.algorithm.preshuffle = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check that identifier contains expected components + EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape + EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape + EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape + EXPECT_NE(id.find("persist"), std::string::npos); // persistent flag +} + +TEST(KernelKeyTest, EncodeIdentifierWithFusion) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "Relu"; + key.signature.num_d_tensors = 2; + key.algorithm.tile_shape.m = 128; + key.algorithm.tile_shape.n = 128; + key.algorithm.tile_shape.k = 64; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 16; + key.algorithm.warp_tile_shape.n = 16; + key.algorithm.warp_tile_shape.k = 32; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check fusion-specific components + EXPECT_NE(id.find("Relu"), std::string::npos); + EXPECT_NE(id.find("_d2"), std::string::npos); + EXPECT_NE(id.find("nopers"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSplitK) +{ + KernelKey key; + key.signature.split_k = 4; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_splitk4"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSparsity) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = true; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_sparse"), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key_extended.cpp b/dispatcher/tests/test_kernel_key_extended.cpp new file mode 100644 index 0000000000..1c6b5bcba0 --- /dev/null +++ b/dispatcher/tests/test_kernel_key_extended.cpp @@ -0,0 +1,453 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for KernelKey - covers all data types, layouts, pipelines + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// DataType Tests +// ============================================================================= + +class DataTypeTest : public ::testing::Test +{ + protected: + void SetUp() override {} +}; + +TEST_F(DataTypeTest, AllDataTypesExist) +{ + // Every DataType should be accessible + std::vector all_types = {DataType::FP16, + DataType::BF16, + DataType::FP32, + DataType::FP64, + DataType::INT8, + DataType::INT4, + DataType::INT32, + DataType::FP8, + DataType::BF8, + DataType::UNKNOWN}; + + EXPECT_EQ(all_types.size(), 10); +} + +TEST_F(DataTypeTest, DataTypesAreDifferent) +{ + EXPECT_NE(DataType::FP16, DataType::BF16); + EXPECT_NE(DataType::FP16, DataType::FP32); + EXPECT_NE(DataType::INT8, DataType::INT4); +} + +// ============================================================================= +// LayoutTag Tests +// ============================================================================= + +class LayoutTagTest : public ::testing::Test +{ +}; + +TEST_F(LayoutTagTest, AllLayoutsExist) +{ + std::vector all_layouts = { + LayoutTag::RowMajor, LayoutTag::ColMajor, LayoutTag::PackedExternal}; + + EXPECT_EQ(all_layouts.size(), 3); +} + +TEST_F(LayoutTagTest, LayoutsAreDifferent) { EXPECT_NE(LayoutTag::RowMajor, LayoutTag::ColMajor); } + +// ============================================================================= +// Pipeline Tests +// ============================================================================= + +class PipelineTest : public ::testing::Test +{ +}; + +TEST_F(PipelineTest, AllPipelinesExist) +{ + std::vector all_pipelines = {Pipeline::Mem, + Pipeline::CompV1, + Pipeline::CompV2, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::CompV5, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + EXPECT_EQ(all_pipelines.size(), 8); +} + +TEST_F(PipelineTest, PipelinesAreDifferent) +{ + EXPECT_NE(Pipeline::Mem, Pipeline::CompV4); + EXPECT_NE(Pipeline::CompV3, Pipeline::CompV4); +} + +// ============================================================================= +// Scheduler Tests +// ============================================================================= + +class SchedulerTest : public ::testing::Test +{ +}; + +TEST_F(SchedulerTest, AllSchedulersExist) +{ + std::vector all_schedulers = { + Scheduler::Auto, Scheduler::Intrawave, Scheduler::Interwave}; + + EXPECT_EQ(all_schedulers.size(), 3); +} + +// ============================================================================= +// Epilogue Tests +// ============================================================================= + +class EpilogueTest : public ::testing::Test +{ +}; + +TEST_F(EpilogueTest, AllEpiloguesExist) +{ + std::vector all_epilogues = {Epilogue::None, + Epilogue::Default, + Epilogue::CShuffle, + Epilogue::Bias, + Epilogue::Activation, + Epilogue::BiasActivation}; + + EXPECT_EQ(all_epilogues.size(), 6); +} + +// ============================================================================= +// KernelKey::Signature Tests +// ============================================================================= + +class SignatureTest : public ::testing::Test +{ + protected: + KernelKey::Signature CreateDefaultSignature() + { + KernelKey::Signature sig; + sig.dtype_a = DataType::FP16; + sig.dtype_b = DataType::FP16; + sig.dtype_c = DataType::FP16; + sig.dtype_acc = DataType::FP32; + sig.layout_a = LayoutTag::RowMajor; + sig.layout_b = LayoutTag::ColMajor; + sig.layout_c = LayoutTag::RowMajor; + sig.transpose_a = false; + sig.transpose_b = false; + sig.grouped = false; + sig.split_k = 1; + sig.elementwise_op = "PassThrough"; + sig.num_d_tensors = 0; + sig.structured_sparsity = false; + return sig; + } +}; + +TEST_F(SignatureTest, DefaultValuesAreReasonable) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + EXPECT_EQ(sig.split_k, 1); + EXPECT_FALSE(sig.grouped); + EXPECT_FALSE(sig.structured_sparsity); +} + +TEST_F(SignatureTest, AllDataTypeCombinations) +{ + // Test various data type combinations that should be valid + std::vector> valid_combos = { + {DataType::FP16, DataType::FP16, DataType::FP16, DataType::FP32}, + {DataType::BF16, DataType::BF16, DataType::BF16, DataType::FP32}, + {DataType::FP32, DataType::FP32, DataType::FP32, DataType::FP32}, + {DataType::INT8, DataType::INT8, DataType::INT8, DataType::INT32}, + }; + + for(const auto& [a, b, c, acc] : valid_combos) + { + KernelKey::Signature sig; + sig.dtype_a = a; + sig.dtype_b = b; + sig.dtype_c = c; + sig.dtype_acc = acc; + + EXPECT_EQ(sig.dtype_a, a); + EXPECT_EQ(sig.dtype_b, b); + EXPECT_EQ(sig.dtype_c, c); + EXPECT_EQ(sig.dtype_acc, acc); + } +} + +TEST_F(SignatureTest, AllLayoutCombinations) +{ + std::vector layout_codes = { + "rrr", "rcr", "crr", "ccr", "rrc", "rcc", "crc", "ccc"}; + + for(const std::string& code : layout_codes) + { + KernelKey::Signature sig = CreateDefaultSignature(); + sig.layout_a = (code[0] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_b = (code[1] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_c = (code[2] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + + // Just verify assignment works + EXPECT_TRUE(sig.layout_a == LayoutTag::RowMajor || sig.layout_a == LayoutTag::ColMajor); + } +} + +TEST_F(SignatureTest, SplitKValues) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + + std::vector valid_split_k = {1, 2, 4, 8, 16}; + for(auto sk : valid_split_k) + { + sig.split_k = sk; + EXPECT_EQ(sig.split_k, sk); + } +} + +// ============================================================================= +// KernelKey::Algorithm Tests +// ============================================================================= + +class AlgorithmTest : public ::testing::Test +{ + protected: + KernelKey::Algorithm CreateDefaultAlgorithm() + { + KernelKey::Algorithm algo; + algo.tile_shape = {256, 256, 32}; + algo.wave_shape = {2, 2, 1}; + algo.warp_tile_shape = {32, 32, 16}; + algo.pipeline = Pipeline::CompV4; + algo.scheduler = Scheduler::Intrawave; + algo.epilogue = Epilogue::CShuffle; + algo.block_size = 256; + algo.double_buffer = true; + algo.persistent = false; + algo.preshuffle = false; + algo.transpose_c = false; + algo.num_wave_groups = 1; + return algo; + } +}; + +TEST_F(AlgorithmTest, CommonTileShapes) +{ + std::vector> valid_tiles = { + {64, 64, 32}, + {128, 128, 32}, + {128, 128, 64}, + {256, 256, 32}, + {256, 256, 64}, + {256, 128, 32}, + {128, 256, 32}, + }; + + for(const auto& [m, n, k] : valid_tiles) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.tile_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.tile_shape.m, m); + EXPECT_EQ(algo.tile_shape.n, n); + EXPECT_EQ(algo.tile_shape.k, k); + } +} + +TEST_F(AlgorithmTest, CommonWarpConfigs) +{ + std::vector> valid_warps = { + {1, 4, 1}, + {2, 2, 1}, + {4, 1, 1}, + {1, 2, 1}, + {2, 1, 1}, + }; + + for(const auto& [m, n, k] : valid_warps) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.wave_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.wave_shape.m, m); + EXPECT_EQ(algo.wave_shape.n, n); + EXPECT_EQ(algo.wave_shape.k, k); + } +} + +TEST_F(AlgorithmTest, AllPipelines) +{ + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + + std::vector pipelines = {Pipeline::Mem, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + for(Pipeline p : pipelines) + { + algo.pipeline = p; + EXPECT_EQ(algo.pipeline, p); + } +} + +// ============================================================================= +// KernelKey Identifier Encoding Tests +// ============================================================================= + +class IdentifierEncodingTest : public ::testing::Test +{ +}; + +TEST_F(IdentifierEncodingTest, UniqueIdentifiersForDifferentConfigs) +{ + std::set identifiers; + + // Generate multiple configurations + for(int tile_m : {128, 256}) + { + for(int wave_m : {1, 2, 4}) + { + for(bool persistent : {true, false}) + { + KernelKey key = make_test_key(tile_m); + key.algorithm.wave_shape.m = wave_m; + key.algorithm.persistent = persistent; + + std::string id = key.encode_identifier(); + EXPECT_TRUE(identifiers.find(id) == identifiers.end()) + << "Duplicate identifier: " << id; + identifiers.insert(id); + } + } + } + + // Should have generated 2 * 3 * 2 = 12 unique identifiers + EXPECT_EQ(identifiers.size(), 12); +} + +TEST_F(IdentifierEncodingTest, IdentifierContainsTileShape) +{ + KernelKey key = make_test_key(256, 128, 64); + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("256x128x64"), std::string::npos) + << "Identifier should contain tile shape: " << id; +} + +TEST_F(IdentifierEncodingTest, IdentifierContainsWarpConfig) +{ + KernelKey key = make_test_key(256); + key.algorithm.wave_shape = {4, 2, 1}; + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("4x2x1"), std::string::npos) + << "Identifier should contain warp config: " << id; +} + +TEST_F(IdentifierEncodingTest, IdentifierReflectsPersistence) +{ + KernelKey persistent_key = make_test_key(256); + persistent_key.algorithm.persistent = true; + + KernelKey non_persistent_key = make_test_key(256); + non_persistent_key.algorithm.persistent = false; + + std::string persistent_id = persistent_key.encode_identifier(); + std::string non_persistent_id = non_persistent_key.encode_identifier(); + + EXPECT_NE(persistent_id, non_persistent_id); + EXPECT_NE(persistent_id.find("persist"), std::string::npos); + EXPECT_NE(non_persistent_id.find("nopers"), std::string::npos); +} + +// ============================================================================= +// KernelKey Equality Tests +// ============================================================================= + +class KeyEqualityTest : public ::testing::Test +{ +}; + +TEST_F(KeyEqualityTest, IdenticalKeysAreEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); +} + +TEST_F(KeyEqualityTest, DifferentTileShapesNotEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32); + KernelKey key2 = make_test_key(128, 128, 32); + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentDataTypesNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.dtype_a = DataType::BF16; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentLayoutsNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.layout_a = LayoutTag::ColMajor; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentGfxArchNotEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx90a"); + + EXPECT_NE(key1, key2); +} + +// ============================================================================= +// ElementwiseOps Tests +// ============================================================================= + +class ElementwiseOpsTest : public ::testing::Test +{ +}; + +TEST_F(ElementwiseOpsTest, CanUseInKernelKey) +{ + KernelKey key = make_test_key(256); + + key.signature.elementwise_op = "Relu"; + EXPECT_EQ(key.signature.elementwise_op, "Relu"); + + key.signature.elementwise_op = "Gelu"; + EXPECT_EQ(key.signature.elementwise_op, "Gelu"); + + key.signature.elementwise_op = "PassThrough"; + EXPECT_EQ(key.signature.elementwise_op, "PassThrough"); +} diff --git a/dispatcher/tests/test_minimal.cpp b/dispatcher/tests/test_minimal.cpp new file mode 100644 index 0000000000..22efc2524c --- /dev/null +++ b/dispatcher/tests/test_minimal.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Minimal test: Verify dispatcher can select and run a kernel +#include +#include +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +int main() +{ + std::cout << "Minimal Dispatcher Test\n"; + std::cout << "=======================\n\n"; + + // Create a mock kernel for testing + KernelKey key = make_test_key(128, 128, 64, "gfx942"); + auto kernel = std::make_shared(key, "test_kernel_128x128x64", true); + + // Register kernel + Registry::instance().clear(); + Registry::instance().register_kernel(kernel); + + std::cout << "OK Registered kernel: " << kernel->get_name() << "\n"; + + // Create dispatcher and problem + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + std::cout << "OK Created problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K + << "\n"; + + // Select kernel + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + + std::cout << "OK Selected kernel: " << selected->get_name() << "\n"; + + // Mock execution (no actual GPU computation in mock kernel) + void* a_ptr = nullptr; + void* b_ptr = nullptr; + void* c_ptr = nullptr; + + float time = dispatcher.run(a_ptr, b_ptr, c_ptr, problem); + + std::cout << "OK Executed kernel: " << time << " ms\n"; + std::cout << "\n[OK] Minimal test passed!\n"; + + return 0; +} diff --git a/dispatcher/tests/test_mock_kernel.cpp b/dispatcher/tests/test_mock_kernel.cpp new file mode 100644 index 0000000000..fd8f3f4baa --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.cpp @@ -0,0 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_mock_kernel.hpp" + +// Empty file - implementation is in header diff --git a/dispatcher/tests/test_mock_kernel.hpp b/dispatcher/tests/test_mock_kernel.hpp new file mode 100644 index 0000000000..7d511719a8 --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace test { + +/// Mock kernel instance for testing dispatcher functionality +/// Supports configurable behavior for testing different scenarios +class MockKernelInstance : public KernelInstance +{ + public: + /// Constructor + /// @param key Kernel configuration key + /// @param name Human-readable kernel name + /// @param supports_all Whether this kernel supports all problems (default: true) + explicit MockKernelInstance(const KernelKey& key, + const std::string& name, + bool supports_all = true) + : key_(key), name_(name), supports_all_(supports_all), execution_count_(0) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + if(supports_all_) + { + return problem.is_valid(); + } + // For testing: only support problems where M/N/K are divisible by tile sizes + return problem.is_valid() && (problem.M % key_.algorithm.tile_shape.m == 0) && + (problem.N % key_.algorithm.tile_shape.n == 0) && + (problem.K % key_.algorithm.tile_shape.k == 0); + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + execution_count_++; + // Simulate execution time (1ms for testing) + return 1.0f; + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + // Mock validation always passes + return true; + } + + /// Get execution count (for testing) + int get_execution_count() const { return execution_count_; } + + /// Reset execution count + void reset_execution_count() { execution_count_ = 0; } + + /// Set whether this kernel supports all problems + void set_supports_all(bool supports_all) { supports_all_ = supports_all; } + + private: + KernelKey key_; + std::string name_; + bool supports_all_; + mutable int execution_count_; +}; + +/// Helper function to create a test kernel key +inline KernelKey make_test_key(std::uint16_t tile_m = 256, + std::uint16_t tile_n = 256, + std::uint16_t tile_k = 32, + const std::string& gfx_arch = "gfx942") +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape.m = tile_m; + key.algorithm.tile_shape.n = tile_n; + key.algorithm.tile_shape.k = tile_k; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; +} + +} // namespace test +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/test_problem.cpp b/dispatcher/tests/test_problem.cpp new file mode 100644 index 0000000000..7d5500e320 --- /dev/null +++ b/dispatcher/tests/test_problem.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Problem using Google Test + +#include "ck_tile/dispatcher/problem.hpp" +#include + +using namespace ck_tile::dispatcher; + +TEST(ProblemTest, DefaultConstruction) +{ + Problem p; + EXPECT_EQ(p.M, 0); + EXPECT_EQ(p.N, 0); + EXPECT_EQ(p.K, 0); + EXPECT_EQ(p.k_batch, 1); + EXPECT_FALSE(p.is_valid()); +} + +TEST(ProblemTest, ConstructorWithDimensions) +{ + Problem p(1024, 1024, 1024); + EXPECT_EQ(p.M, 1024); + EXPECT_EQ(p.N, 1024); + EXPECT_EQ(p.K, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, Validation) +{ + Problem p; + + // Invalid: all zeros + p.M = 0; + p.N = 0; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Invalid: negative + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); + + // Invalid: zero K + p.M = 1024; + p.N = 1024; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Valid + p.M = 1024; + p.N = 1024; + p.K = 1024; + EXPECT_TRUE(p.is_valid()); + + // Invalid k_batch + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); + + p.k_batch = 1; + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, NumOps) +{ + Problem p(100, 200, 300); + + // 2 * M * N * K (multiply-add = 2 ops) + std::int64_t expected = 2 * 100 * 200 * 300; + EXPECT_EQ(p.num_ops(), expected); +} + +TEST(ProblemTest, Configuration) +{ + Problem p(1024, 1024, 1024); + + // Set preferences + p.prefer_persistent = true; + p.enable_validation = true; + p.smem_budget = 65536; + p.k_batch = 2; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_EQ(p.k_batch, 2); +} + +TEST(ProblemTest, LargeDimensions) +{ + Problem p(1024, 1024, 1024); // Use smaller but still large dimensions + EXPECT_TRUE(p.is_valid()); + EXPECT_GT(p.num_ops(), 0); +} diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp new file mode 100644 index 0000000000..21ea545292 --- /dev/null +++ b/dispatcher/tests/test_problem_extended.cpp @@ -0,0 +1,457 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Problem - covers dimension inference, validation, edge cases + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +using namespace ck_tile::dispatcher; + +// ============================================================================= +// Dimension Inference Tests +// ============================================================================= + +class ProblemDimensionInferenceTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) +{ + // A: M×K (1024×512), B: K×N (512×2048) + auto problem = Problem::from_ab(1024, 512, 512, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) +{ + // A: 1024×512, B: 512×2048, C: 1024×2048 + auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) +{ + // A stored as K×M (transposed) + TensorShape A{512, 1024, true}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) +{ + TensorShape A{1024, 512, false}; + // B stored as N×K (transposed) + TensorShape B{2048, 512, true}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Validation Tests +// ============================================================================= + +class ProblemValidationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidationTest, ValidProblem) +{ + Problem p(1024, 1024, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroM) +{ + Problem p(0, 1024, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroN) +{ + Problem p(1024, 0, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroK) +{ + Problem p(1024, 1024, 0); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, NegativeM) +{ + Problem p; + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ValidKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 4; + EXPECT_TRUE(p.is_valid()); +} + +// ============================================================================= +// num_ops Tests +// ============================================================================= + +class ProblemNumOpsTest : public ::testing::Test +{ +}; + +TEST_F(ProblemNumOpsTest, SmallProblem) +{ + Problem p(10, 20, 30); + // 2 * M * N * K = 2 * 10 * 20 * 30 = 12000 + EXPECT_EQ(p.num_ops(), 12000); +} + +TEST_F(ProblemNumOpsTest, SymmetricProblem) +{ + Problem p(1024, 1024, 1024); + // 2 * 1024^3 = 2,147,483,648 + EXPECT_EQ(p.num_ops(), 2LL * 1024 * 1024 * 1024); +} + +TEST_F(ProblemNumOpsTest, AsymmetricProblem) +{ + Problem p(512, 2048, 256); + EXPECT_EQ(p.num_ops(), 2LL * 512 * 2048 * 256); +} + +TEST_F(ProblemNumOpsTest, LargeProblem) +{ + Problem p(4096, 4096, 4096); + std::int64_t expected = 2LL * 4096 * 4096 * 4096; + EXPECT_EQ(p.num_ops(), expected); + EXPECT_GT(p.num_ops(), 0); // No overflow +} + +// ============================================================================= +// Edge Cases +// ============================================================================= + +class ProblemEdgeCasesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemEdgeCasesTest, MinimumValidSize) +{ + Problem p(1, 1, 1); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_TallMatrix) +{ + Problem p(8192, 64, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_WideMatrix) +{ + Problem p(64, 8192, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_DeepK) +{ + Problem p(1024, 1024, 8192); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, SmallK) +{ + Problem p(1024, 1024, 16); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonPowerOf2Dimensions) +{ + Problem p(1000, 2000, 300); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2LL * 1000 * 2000 * 300); +} + +TEST_F(ProblemEdgeCasesTest, PrimeDimensions) +{ + Problem p(997, 1009, 1013); // All prime numbers + EXPECT_TRUE(p.is_valid()); +} + +// ============================================================================= +// Configuration Tests +// ============================================================================= + +class ProblemConfigurationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemConfigurationTest, DefaultConfiguration) +{ + Problem p(1024, 1024, 1024); + + EXPECT_FALSE(p.prefer_persistent); + EXPECT_FALSE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 0); + EXPECT_EQ(p.k_batch, 1); +} + +TEST_F(ProblemConfigurationTest, SetPersistentPreference) +{ + Problem p(1024, 1024, 1024); + p.prefer_persistent = true; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetSmemBudget) +{ + Problem p(1024, 1024, 1024); + p.smem_budget = 65536; // 64KB + + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetKBatch) +{ + Problem p(1024, 1024, 1024); + + for(int kb : {1, 2, 4, 8, 16}) + { + p.k_batch = kb; + EXPECT_EQ(p.k_batch, kb); + EXPECT_TRUE(p.is_valid()); + } +} + +// ============================================================================= +// Copy and Assignment Tests +// ============================================================================= + +class ProblemCopyTest : public ::testing::Test +{ +}; + +TEST_F(ProblemCopyTest, CopyConstruction) +{ + Problem p1(1024, 2048, 512); + p1.prefer_persistent = true; + p1.k_batch = 4; + + Problem p2(p1); + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); + EXPECT_TRUE(p2.prefer_persistent); + EXPECT_EQ(p2.k_batch, 4); +} + +TEST_F(ProblemCopyTest, Assignment) +{ + Problem p1(1024, 2048, 512); + Problem p2(256, 256, 256); + + p2 = p1; + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); +} + +// ============================================================================= +// Builder Tests +// ============================================================================= + +class ProblemBuilderTest : public ::testing::Test +{ +}; + +TEST_F(ProblemBuilderTest, BasicBuild) +{ + auto problem = ProblemBuilder().dimensions(1024, 2048, 512).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemBuilderTest, WithSplitK) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).split_k(4).build(); + + EXPECT_EQ(problem.k_batch, 4); +} + +TEST_F(ProblemBuilderTest, WithPersistent) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).persistent(true).build(); + + EXPECT_TRUE(problem.prefer_persistent); +} + +TEST_F(ProblemBuilderTest, WithSmemBudget) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).smem_budget(65536).build(); + + EXPECT_EQ(problem.smem_budget, 65536); +} + +TEST_F(ProblemBuilderTest, ChainedConfiguration) +{ + auto problem = ProblemBuilder() + .dimensions(2048, 2048, 1024) + .split_k(2) + .persistent(true) + .smem_budget(32768) + .validate(true) + .build(); + + EXPECT_EQ(problem.M, 2048); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 1024); + EXPECT_EQ(problem.k_batch, 2); + EXPECT_TRUE(problem.prefer_persistent); + EXPECT_EQ(problem.smem_budget, 32768); + EXPECT_TRUE(problem.enable_validation); +} + +TEST_F(ProblemBuilderTest, FromAB) +{ + auto problem = ProblemBuilder().from_ab(1024, 512, 512, 2048).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Dimension Mismatch Error Tests +// ============================================================================= + +class ProblemDimensionErrorTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionErrorTest, KMismatchThrows) +{ + EXPECT_THROW((void)Problem::from_ab(1024, 512, 256, 2048), // K mismatch: 512 vs 256 + std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, MDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{512, 2048, false}; // M mismatch: A says M=1024, C says M=512 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, NDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 1024, false}; // N mismatch: B says N=2048, C says N=1024 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +// ============================================================================= +// Validate Sizes Tests +// ============================================================================= + +class ProblemValidateSizesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidateSizesTest, CorrectSizes) +{ + Problem p(1024, 2048, 512); + + // This should not throw + EXPECT_NO_THROW(p.validate_sizes(1024 * 512, // A size + 512 * 2048, // B size + 1024 * 2048 // C size + )); +} + +TEST_F(ProblemValidateSizesTest, WrongASizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 256, // Wrong A size + 512 * 2048, + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongBSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 256 * 2048, // Wrong B size + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongCSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 512 * 2048, + 512 * 1024 // Wrong C size + ), + std::invalid_argument); +} diff --git a/dispatcher/tests/test_real_kernel_correctness.cpp b/dispatcher/tests/test_real_kernel_correctness.cpp new file mode 100644 index 0000000000..e753f04e19 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_correctness.cpp @@ -0,0 +1,232 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Correctness test with real GPU kernel + * Validates GPU results against CPU reference implementation + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM +// A: RowMajor (M x K) - A[m,k] = A[m*K + k] +// B: ColumnMajor (K x N) - B[k,n] = B[k + n*K] +// C: RowMajor (M x N) - C[m,n] = C[m*N + n] +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A is row-major: A[m,k] = A[m*K + k] + // B is column-major: B[k,n] = B[k + n*K] + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Correctness Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Test with random matrices + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Test configuration:\n"; + std::cout << " Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << " Method: Random matrices vs CPU reference\n\n"; + + // Random number generation + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(-1.0f, 1.0f); + + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Initialize with random values + std::cout << "Initializing random matrices...\n"; + for(int i = 0; i < M * K; i++) + { + A_host[i] = ADataType(dist(rng)); + } + for(int i = 0; i < K * N; i++) + { + B_host[i] = BDataType(dist(rng)); + } + + // GPU execution + std::cout << "Executing on GPU...\n"; + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Problem problem(M, N, K); + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + std::cout << "OK GPU execution complete: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK GPU performance: " << tflops << " TFLOPS\n\n"; + + // CPU reference + std::cout << "Computing CPU reference...\n"; + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + std::cout << "OK CPU reference complete\n\n"; + + // Validation + std::cout << "Validating results...\n"; + + int num_correct = 0; + float max_rel_error = 0.0f; + float max_abs_error = 0.0f; + const float tolerance = 0.02f; // 2% for FP16 + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + + float abs_error = std::abs(gpu_val - cpu_val); + float rel_error = abs_error / (std::abs(cpu_val) + 1e-5f); + + max_abs_error = std::max(max_abs_error, abs_error); + max_rel_error = std::max(max_rel_error, rel_error); + + if(rel_error < tolerance) + { + num_correct++; + } + } + + float accuracy = 100.0f * num_correct / (M * N); + + std::cout << "\nValidation Results:\n"; + std::cout << " Correct elements: " << num_correct << "/" << M * N << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + std::cout << " Max absolute error: " << max_abs_error << "\n"; + std::cout << " Max relative error: " << max_rel_error << "\n"; + std::cout << " Tolerance: " << tolerance << " (2%)\n\n"; + + // Show sample comparisons + std::cout << "Sample results (first 5 elements):\n"; + std::cout << " Index | GPU Result | CPU Result | Error\n"; + std::cout << " ------|------------|------------|-------\n"; + + for(int i = 0; i < 5; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val); + printf(" %-5d | %10.4f | %10.4f | %.4f\n", i, gpu_val, cpu_val, error); + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] CORRECTNESS TEST PASSED\n"; + std::cout << " GPU results match CPU reference within tolerance\n"; + return 0; + } + else + { + std::cout << "[FAIL] CORRECTNESS TEST FAILED\n"; + std::cout << " Accuracy too low: " << accuracy << "%\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp new file mode 100644 index 0000000000..f23f684631 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -0,0 +1,213 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Multi-size real kernel test: Test multiple problem sizes with real GPU kernel + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +struct TestResult +{ + int M, N, K; + float time_ms; + double tflops; + int correct; + int total; + bool passed; +}; + +TestResult run_test(Dispatcher& dispatcher, int M, int N, int K) +{ + TestResult result = {M, N, K, 0.0f, 0.0, 0, M * N, false}; + + // Allocate and prepare data + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + + // Initialize: A=1, B=1, expected C=K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + result.time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate performance + double flops = 2.0 * M * N * K; + result.tflops = (flops / (result.time_ms * 1e-3)) / 1e12; + + // Copy result and validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + { + result.correct++; + } + } + + result.passed = (result.correct == result.total); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return result; +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Multi-Size Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Using kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + std::cout << "Running tests on multiple problem sizes...\n"; + std::cout << "===========================================\n\n"; + + // Test various sizes (all multiples of tile size) + std::vector> test_sizes = { + {128, 128, 128}, // Small + {256, 256, 256}, // Medium + {512, 512, 512}, // Large + {1024, 1024, 1024}, // Very large + {128, 512, 256}, // Non-square + {512, 128, 384}, // Non-square + }; + + std::vector results; + int num_passed = 0; + + for(const auto& [M, N, K] : test_sizes) + { + std::cout << "Testing M=" << M << " N=" << N << " K=" << K << "...\n"; + + auto result = run_test(dispatcher, M, N, K); + results.push_back(result); + + std::cout << " Time: " << result.time_ms << " ms\n"; + std::cout << " Performance: " << result.tflops << " TFLOPS\n"; + std::cout << " Accuracy: " << (100.0f * result.correct / result.total) << "%\n"; + std::cout << " Status: " << (result.passed ? "[OK] PASS" : "[FAIL] FAIL") << "\n\n"; + + if(result.passed) + num_passed++; + } + + // Summary + std::cout << "===========================================\n"; + std::cout << "Summary\n"; + std::cout << "===========================================\n\n"; + + std::cout << "Results by size:\n"; + std::cout << " Size | Time (ms) | TFLOPS | Accuracy | Status\n"; + std::cout << " ---------------|-----------|--------|----------|--------\n"; + + for(const auto& r : results) + { + char size_str[32]; + snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + + printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", + size_str, + r.time_ms, + r.tflops, + 100.0f * r.correct / r.total, + r.passed ? "[OK]" : "[FAIL]"); + } + + std::cout << "\n"; + std::cout << "Tests passed: " << num_passed << "/" << results.size() << "\n"; + + if(num_passed == results.size()) + { + std::cout << "\n[OK] ALL TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "\n[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp new file mode 100644 index 0000000000..ff3d635968 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Performance test with real GPU kernel + * Measures and reports detailed performance metrics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Performance Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + std::cout << "Device: AMD Instinct MI325X (gfx942)\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Performance benchmark sizes + std::vector> benchmarks = { + {128, 128, 128, "Tiny"}, + {256, 256, 256, "Small"}, + {512, 512, 512, "Medium"}, + {1024, 1024, 1024, "Large"}, + {2048, 2048, 2048, "Very Large"}, + }; + + std::cout << "Performance Benchmark Results\n"; + std::cout << "=============================\n\n"; + + std::cout << " Size | Time (ms) | TFLOPS | BW (GB/s) | Status\n"; + std::cout << " ----------|-----------|--------|-----------|--------\n"; + + bool all_passed = true; + + for(const auto& [M, N, K, label] : benchmarks) + { + // Prepare data + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + float time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate metrics + double flops = 2.0 * M * N * K; + double tflops = (flops / (time_ms * 1e-3)) / 1e12; + + // Bandwidth (A + B read, C write) + double bytes = (M * K + K * N + M * N) * sizeof(CDataType); + double bandwidth_gbs = (bytes / (time_ms * 1e-3)) / 1e9; + + // Validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + correct++; + } + + bool passed = (correct == M * N); + all_passed = all_passed && passed; + + char size_label[32]; + snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + + printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", + size_label, + time_ms, + tflops, + bandwidth_gbs, + passed ? "[OK]" : "[FAIL]"); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + } + + std::cout << "\n"; + + if(all_passed) + { + std::cout << "[OK] ALL PERFORMANCE TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_simple.cpp b/dispatcher/tests/test_real_kernel_simple.cpp new file mode 100644 index 0000000000..72e3a5fc87 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_simple.cpp @@ -0,0 +1,201 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Simple real kernel test using tile_engine style (single kernel with -include) + * This follows the proven pattern from the examples + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag +// It defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// Reference CPU GEMM +template +void reference_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Simple Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + // Test size (must be multiple of tile size) + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Create kernel key + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + // Create and register kernel + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + std::cout << "OK Registered kernel\n"; + + // Create dispatcher + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + std::cout << "OK Selected kernel: " << selected->get_name() << "\n\n"; + + // Prepare data + std::cout << "Preparing test data...\n"; + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Simple test: A=1, B=1, C should be K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + // Allocate GPU memory + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + std::cout << "OK Data ready on GPU\n\n"; + + // Execute + std::cout << "Executing GPU kernel...\n"; + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + std::cout << "OK GPU time: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK Performance: " << tflops << " TFLOPS\n\n"; + + // Copy result + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Validate + std::cout << "Validating (expected: all elements = " << K << ")...\n"; + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + float val = float(C_gpu[i]); + if(std::abs(val - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << "Accuracy: " << accuracy << "% (" << correct << "/" << M * N << ")\n"; + + // Show samples + std::cout << "\nFirst 5 results:\n"; + for(int i = 0; i < 5; i++) + { + std::cout << " C[" << i << "] = " << float(C_gpu[i]) << " (expected " << K << ")\n"; + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] TEST PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] TEST FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_registry.cpp b/dispatcher/tests/test_registry.cpp new file mode 100644 index 0000000000..4e5bf718df --- /dev/null +++ b/dispatcher/tests/test_registry.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Registry using Google Test + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(RegistryTest, Registration) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + bool registered = registry.register_kernel(kernel); + EXPECT_TRUE(registered); + EXPECT_EQ(registry.size(), 1); +} + +TEST(RegistryTest, Lookup) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + // Lookup by key + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "test_kernel"); + + // Lookup by identifier + std::string id = key.encode_identifier(); + auto found2 = registry.lookup(id); + ASSERT_NE(found2, nullptr); + EXPECT_EQ(found2->get_name(), "test_kernel"); + + // Lookup non-existent + auto key2 = make_test_key(128); + auto not_found = registry.lookup(key2); + EXPECT_EQ(not_found, nullptr); +} + +TEST(RegistryTest, Priority) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel_low"); + auto kernel2 = std::make_shared(key, "kernel_high"); + + // Register with low priority + registry.register_kernel(kernel1, Registry::Priority::Low); + + // Try to register with normal priority (should replace) + bool replaced = registry.register_kernel(kernel2, Registry::Priority::Normal); + EXPECT_TRUE(replaced); + + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); + + // Try to register with low priority again (should fail) + auto kernel3 = std::make_shared(key, "kernel_low2"); + bool not_replaced = registry.register_kernel(kernel3, Registry::Priority::Low); + EXPECT_FALSE(not_replaced); + + found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); +} + +TEST(RegistryTest, GetAll) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + registry.register_kernel(kernel1); + registry.register_kernel(kernel2); + + auto all = registry.get_all(); + EXPECT_EQ(all.size(), 2); +} + +TEST(RegistryTest, Filter) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Create kernels with different tile sizes + for(int tile_m : {128, 256, 512}) + { + auto key = make_test_key(tile_m); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile_m)); + registry.register_kernel(kernel); + } + + // Filter for large tiles (>= 256) + auto large_tiles = registry.filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large_tiles.size(), 2); +} + +TEST(RegistryTest, Clear) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + EXPECT_EQ(registry.size(), 1); + + registry.clear(); + EXPECT_EQ(registry.size(), 0); +} + +TEST(RegistryTest, MultipleKernels) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Register multiple kernels + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + registry.register_kernel(kernel); + } + + EXPECT_EQ(registry.size(), 10); + + // Verify all can be looked up + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_" + std::to_string(i)); + } +} + +TEST(RegistryTest, Singleton) +{ + Registry& reg1 = Registry::instance(); + Registry& reg2 = Registry::instance(); + + // Should be the same instance + EXPECT_EQ(®1, ®2); +} diff --git a/dispatcher/tests/test_registry_extended.cpp b/dispatcher/tests/test_registry_extended.cpp new file mode 100644 index 0000000000..d173e1a38d --- /dev/null +++ b/dispatcher/tests/test_registry_extended.cpp @@ -0,0 +1,503 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Registry - covers multiple registries, merging, filtering + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Registration Tests +// ============================================================================= + +class RegistryBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryBasicTest, RegisterSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + EXPECT_EQ(Registry::instance().size(), 1); +} + +TEST_F(RegistryBasicTest, RegisterNullKernel) +{ + EXPECT_FALSE(Registry::instance().register_kernel(nullptr)); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryBasicTest, RegisterMultipleKernels) +{ + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + } + EXPECT_EQ(Registry::instance().size(), 100); +} + +TEST_F(RegistryBasicTest, RegisterDuplicateKey) +{ + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel1"); + auto kernel2 = std::make_shared(key, "kernel2"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel1, Registry::Priority::Normal)); + + // Same priority should not replace + EXPECT_FALSE(Registry::instance().register_kernel(kernel2, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "kernel1"); +} + +// ============================================================================= +// Priority Tests +// ============================================================================= + +class RegistryPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryPriorityTest, HigherPriorityReplaces) +{ + auto key = make_test_key(256); + + auto low = std::make_shared(key, "low"); + auto normal = std::make_shared(key, "normal"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(normal, Registry::Priority::Normal)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "normal"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, LowerPriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "first"); +} + +// ============================================================================= +// Lookup Tests +// ============================================================================= + +class RegistryLookupTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register several kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryLookupTest, LookupByKey) +{ + auto key = make_test_key(256); + auto found = Registry::instance().lookup(key); + + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupByIdentifier) +{ + auto key = make_test_key(256); + std::string id = key.encode_identifier(); + + auto found = Registry::instance().lookup(id); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupNonExistent) +{ + auto key = make_test_key(1024); // Not registered + EXPECT_EQ(Registry::instance().lookup(key), nullptr); + EXPECT_EQ(Registry::instance().lookup("nonexistent_id"), nullptr); +} + +TEST_F(RegistryLookupTest, LookupEmptyIdentifier) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +// ============================================================================= +// Filter Tests +// ============================================================================= + +class RegistryFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with various tile sizes + for(int tile : {64, 128, 256, 512, 1024}) + { + auto key = make_test_key(tile); + key.signature.dtype_a = (tile < 256) ? DataType::FP16 : DataType::BF16; + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryFilterTest, FilterByTileSize) +{ + auto large = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large.size(), 3); // 256, 512, 1024 +} + +TEST_F(RegistryFilterTest, FilterByDataType) +{ + auto fp16 = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().signature.dtype_a == DataType::FP16; }); + + EXPECT_EQ(fp16.size(), 2); // 64, 128 +} + +TEST_F(RegistryFilterTest, FilterMatchesNone) +{ + auto none = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m > 2048; }); + + EXPECT_EQ(none.size(), 0); +} + +TEST_F(RegistryFilterTest, FilterMatchesAll) +{ + auto all = Registry::instance().filter([](const KernelInstance& k) { return true; }); + + EXPECT_EQ(all.size(), 5); +} + +// ============================================================================= +// Multiple Registries Tests +// ============================================================================= + +class MultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(MultipleRegistriesTest, CreateIndependentRegistries) +{ + Registry reg1; + Registry reg2; + + reg1.set_name("registry1"); + reg2.set_name("registry2"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "kernel1")); + reg2.register_kernel(std::make_shared(key2, "kernel2")); + + EXPECT_EQ(reg1.size(), 1); + EXPECT_EQ(reg2.size(), 1); + + EXPECT_NE(reg1.lookup(key1), nullptr); + EXPECT_EQ(reg1.lookup(key2), nullptr); + + EXPECT_EQ(reg2.lookup(key1), nullptr); + EXPECT_NE(reg2.lookup(key2), nullptr); +} + +TEST_F(MultipleRegistriesTest, RegistryNaming) +{ + Registry reg; + reg.set_name("my_custom_registry"); + + EXPECT_EQ(reg.get_name(), "my_custom_registry"); +} + +TEST_F(MultipleRegistriesTest, MergeRegistries) +{ + Registry reg1; + Registry reg2; + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + auto key3 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg1.register_kernel(std::make_shared(key2, "k2")); + + reg2.register_kernel(std::make_shared(key3, "k3")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Normal); + combined.merge_from(reg2, Registry::Priority::Normal); + + EXPECT_EQ(combined.size(), 3); + EXPECT_NE(combined.lookup(key1), nullptr); + EXPECT_NE(combined.lookup(key2), nullptr); + EXPECT_NE(combined.lookup(key3), nullptr); +} + +TEST_F(MultipleRegistriesTest, MergeWithPriorityConflict) +{ + Registry reg1; + Registry reg2; + + auto key = make_test_key(256); + + reg1.register_kernel(std::make_shared(key, "from_reg1")); + reg2.register_kernel(std::make_shared(key, "from_reg2")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Low); + combined.merge_from(reg2, Registry::Priority::High); + + EXPECT_EQ(combined.size(), 1); + EXPECT_EQ(combined.lookup(key)->get_name(), "from_reg2"); +} + +TEST_F(MultipleRegistriesTest, SingletonIndependence) +{ + Registry local_reg; + local_reg.set_name("local"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + local_reg.register_kernel(std::make_shared(key1, "local_kernel")); + Registry::instance().register_kernel( + std::make_shared(key2, "global_kernel")); + + EXPECT_EQ(local_reg.size(), 1); + EXPECT_EQ(Registry::instance().size(), 1); + + EXPECT_NE(local_reg.lookup(key1), nullptr); + EXPECT_EQ(local_reg.lookup(key2), nullptr); + + EXPECT_EQ(Registry::instance().lookup(key1), nullptr); + EXPECT_NE(Registry::instance().lookup(key2), nullptr); +} + +// ============================================================================= +// Thread Safety Tests +// ============================================================================= + +class RegistryThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryThreadSafetyTest, ConcurrentRegistrations) +{ + const int num_threads = 10; + const int kernels_per_thread = 100; + + std::vector threads; + std::atomic success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, kernels_per_thread, &success_count]() { + for(int k = 0; k < kernels_per_thread; k++) + { + int tile = t * 1000 + k; // Unique tile size + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + + if(Registry::instance().register_kernel(kernel)) + { + success_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(success_count.load(), num_threads * kernels_per_thread); + EXPECT_EQ(Registry::instance().size(), num_threads * kernels_per_thread); +} + +TEST_F(RegistryThreadSafetyTest, ConcurrentLookups) +{ + // Pre-register kernels + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + const int num_threads = 10; + const int lookups_per_thread = 1000; + std::atomic found_count{0}; + + std::vector threads; + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([lookups_per_thread, &found_count]() { + for(int k = 0; k < lookups_per_thread; k++) + { + auto key = make_test_key(k % 100); + if(Registry::instance().lookup(key) != nullptr) + { + found_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(found_count.load(), num_threads * lookups_per_thread); +} + +// ============================================================================= +// Clear and Size Tests +// ============================================================================= + +class RegistryClearTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryClearTest, ClearEmptyRegistry) +{ + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Should not crash + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, ClearNonEmptyRegistry) +{ + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + EXPECT_EQ(Registry::instance().size(), 10); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, RegisterAfterClear) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// GetAll Tests +// ============================================================================= + +class RegistryGetAllTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryGetAllTest, GetAllEmpty) +{ + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 0); +} + +TEST_F(RegistryGetAllTest, GetAllMultiple) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 5); +} diff --git a/dispatcher/tests/test_regression.cpp b/dispatcher/tests/test_regression.cpp new file mode 100644 index 0000000000..8b5a416ecf --- /dev/null +++ b/dispatcher/tests/test_regression.cpp @@ -0,0 +1,492 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Regression tests for known issues and edge cases. + * Add a new test here whenever a bug is fixed to prevent regression. + */ + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Issue: Uninitialized 'grouped' field in KernelKey caused JSON corruption +// Fix: Ensure all fields in make_test_key() are initialized +// ============================================================================= + +class RegressionGroupedFieldTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInitialized) +{ + KernelKey key = make_test_key(256); + + // grouped should be explicitly initialized + EXPECT_FALSE(key.signature.grouped); + + // Encoding should not crash or produce garbage + std::string id = key.encode_identifier(); + EXPECT_FALSE(id.empty()); + + // ID should not contain garbage characters + for(char c : id) + { + EXPECT_TRUE(std::isprint(c) || c == '_' || c == '-') + << "Invalid character in identifier: " << static_cast(c); + } +} + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInJSON) +{ + KernelKey key = make_test_key(256); + key.signature.grouped = false; + + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + // Export to JSON + std::string json = Registry::instance().export_json(true); + + // JSON should be valid (not contain null bytes or garbage) + EXPECT_FALSE(json.empty()); + + // Should contain the grouped field with proper value + EXPECT_NE(json.find("\"grouped\""), std::string::npos); + EXPECT_NE(json.find("false"), std::string::npos); +} + +// ============================================================================= +// Issue: Priority comparison was incorrect +// Fix: Higher priority should replace lower, same priority should not replace +// ============================================================================= + +class RegressionPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionPriorityTest, LowThenHighReplaces) +{ + auto key = make_test_key(256); + auto low = std::make_shared(key, "low"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, HighThenLowDoesNotReplace) +{ + auto key = make_test_key(256); + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "first"); +} + +// ============================================================================= +// Issue: Empty heuristic caused crash +// Fix: Fall back to FirstFit when heuristic returns empty or invalid results +// ============================================================================= + +class RegressionHeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionHeuristicTest, EmptyHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, AllInvalidHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid1", "invalid2", "invalid3"}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, NullHeuristicSafe) +{ + Dispatcher dispatcher; + + // Don't set any heuristic + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash + auto selected = dispatcher.select_kernel(problem); + // Behavior depends on implementation - may return nullptr or fall back +} + +// ============================================================================= +// Issue: Lookup by empty string caused crash or undefined behavior +// ============================================================================= + +class RegressionLookupTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionLookupTest, EmptyStringLookup) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +TEST_F(RegressionLookupTest, VeryLongStringLookup) +{ + std::string very_long(10000, 'x'); + EXPECT_EQ(Registry::instance().lookup(very_long), nullptr); +} + +TEST_F(RegressionLookupTest, SpecialCharactersLookup) +{ + EXPECT_EQ(Registry::instance().lookup("kernel\0name"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\nname"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\tname"), nullptr); +} + +// ============================================================================= +// Issue: Problem with zero dimensions passed to dispatcher +// ============================================================================= + +class RegressionProblemTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionProblemTest, ZeroMDimension) +{ + Problem problem; + problem.M = 0; + problem.N = 1024; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroNDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 0; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroKDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 1024; + problem.K = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Dispatcher run with null pointers +// ============================================================================= + +class RegressionNullPointerTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionNullPointerTest, RunWithNullPointers) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock kernel doesn't use pointers, so this should work + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); +} + +// ============================================================================= +// Issue: Thread safety - concurrent access to singleton +// ============================================================================= + +class RegressionThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionThreadSafetyTest, SingletonAddressStable) +{ + Registry* addr1 = &Registry::instance(); + Registry* addr2 = &Registry::instance(); + Registry* addr3 = &Registry::instance(); + + EXPECT_EQ(addr1, addr2); + EXPECT_EQ(addr2, addr3); +} + +// ============================================================================= +// Issue: encode_identifier could produce duplicate IDs for different configs +// ============================================================================= + +class RegressionIdentifierTest : public ::testing::Test +{ +}; + +TEST_F(RegressionIdentifierTest, DifferentConfigsDifferentIDs) +{ + // Create two keys that differ only in one field + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.algorithm.persistent = true; // Only difference + + std::string id1 = key1.encode_identifier(); + std::string id2 = key2.encode_identifier(); + + EXPECT_NE(id1, id2) << "Different persistent flag should produce different IDs"; +} + +TEST_F(RegressionIdentifierTest, DifferentTileShapesDifferentIDs) +{ + KernelKey key1 = make_test_key(128, 128, 32); + KernelKey key2 = make_test_key(256, 256, 32); + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +TEST_F(RegressionIdentifierTest, DifferentWarpConfigsDifferentIDs) +{ + KernelKey key1 = make_test_key(256); + key1.algorithm.wave_shape = {2, 2, 1}; + + KernelKey key2 = make_test_key(256); + key2.algorithm.wave_shape = {4, 1, 1}; + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +// ============================================================================= +// Issue: Negative k_batch could cause issues +// ============================================================================= + +class RegressionKBatchTest : public ::testing::Test +{ +}; + +TEST_F(RegressionKBatchTest, ZeroKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, NegativeKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = -1; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, LargeKBatchValid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 1000; + + EXPECT_TRUE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Filter returning shared_ptr leaks +// ============================================================================= + +class RegressionFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionFilterTest, FilterResultsAreValid) +{ + auto results = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 105; }); + + EXPECT_EQ(results.size(), 5); + + for(const auto& kernel : results) + { + EXPECT_NE(kernel, nullptr); + EXPECT_GE(kernel->get_key().algorithm.tile_shape.m, 105); + } +} + +// ============================================================================= +// Issue: Double clear() could cause issues +// ============================================================================= + +class RegressionDoubleClearTest : public ::testing::Test +{ +}; + +TEST_F(RegressionDoubleClearTest, DoubleClearSafe) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Second clear + EXPECT_EQ(Registry::instance().size(), 0); + + // Should still work after double clear + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// Issue: Multiple dispatchers with same registry +// ============================================================================= + +class RegressionMultiDispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionMultiDispatcherTest, MultipleDispatchersShareRegistry) +{ + Dispatcher d1; + Dispatcher d2; + Dispatcher d3; + + Problem problem(1024, 1024, 1024); + + auto k1 = d1.select_kernel(problem); + auto k2 = d2.select_kernel(problem); + auto k3 = d3.select_kernel(problem); + + // All should select the same kernel + EXPECT_NE(k1, nullptr); + EXPECT_EQ(k1, k2); + EXPECT_EQ(k2, k3); +} diff --git a/dispatcher/tests/test_sanity_ck_tile.cpp b/dispatcher/tests/test_sanity_ck_tile.cpp new file mode 100644 index 0000000000..fd28b7e54c --- /dev/null +++ b/dispatcher/tests/test_sanity_ck_tile.cpp @@ -0,0 +1,607 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Sanity check tests to verify CK Tile kernels are actually running on GPU. + * + * These tests verify: + * 1. GPU memory allocation and transfer work correctly + * 2. The dispatcher calls CK Tile infrastructure + * 3. GPU computes correct results (not just zeros) + * 4. Performance is reasonable (not CPU fallback) + * 5. Different problem sizes work correctly + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << "\n"; \ + return 1; \ + } \ + } + +// Reference CPU GEMM for validation +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +// Test helper to setup dispatcher +void setup_dispatcher() +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Registry::Priority::High); +} + +// ============================================================================= +// Test 1: Basic Sanity - All ones multiplication +// ============================================================================= +int test_all_ones() +{ + std::cout << "\n=== Test: All Ones Multiplication ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // All ones * all ones with K=256 should give K=256 for each element + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << K << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 2: Non-Zero Results - Verify GPU actually computed something +// ============================================================================= +int test_non_zero_results() +{ + std::cout << "\n=== Test: Non-Zero Results ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(2.0f)); // All 2s + std::vector B(K * N, BDataType(3.0f)); // All 3s + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // 2 * 3 * K = 6 * 256 = 1536 + float expected = 6.0f * K; + int correct = 0; + int non_zero = 0; + + for(int i = 0; i < M * N; i++) + { + if(float(C[i]) != 0.0f) + non_zero++; + if(std::abs(float(C[i]) - expected) < 10.0f) + { + correct++; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << expected << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Non-zero elements: " << non_zero << "/" << M * N << "\n"; + + if(non_zero == 0) + { + std::cerr << " FAILED: All zeros - GPU may not have run\n"; + return 1; + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 3: Performance Check - Ensure not CPU fallback +// ============================================================================= +int test_performance() +{ + std::cout << "\n=== Test: Performance Check ===\n"; + + const int M = 1024, N = 1024, K = 1024; + const int num_runs = 5; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + // Warmup + dispatcher.run(A_dev, B_dev, C_dev, problem); + HIP_CHECK(hipDeviceSynchronize()); + + // Timed runs + std::vector times; + for(int i = 0; i < num_runs; i++) + { + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + times.push_back(time); + } + + float avg_time = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + float min_time = *std::min_element(times.begin(), times.end()); + + double flops = 2.0 * M * N * K; + double tflops = (flops / (min_time * 1e-3)) / 1e12; + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Problem: " << M << "x" << N << "x" << K << "\n"; + std::cout << " Avg time: " << avg_time << " ms\n"; + std::cout << " Min time: " << min_time << " ms\n"; + std::cout << " Performance: " << tflops << " TFLOPS\n"; + + // GPU should achieve at least 1 TFLOPS for this size + // CPU would be ~0.001 TFLOPS + if(tflops < 1.0) + { + std::cerr << " FAILED: Performance too low - may be CPU fallback\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 4: CPU vs GPU Correctness +// ============================================================================= +int test_vs_cpu_reference() +{ + std::cout << "\n=== Test: CPU vs GPU Correctness ===\n"; + + const int M = 128, N = 128, K = 128; // Small for CPU reference + + // Random-ish values + std::vector A(M * K); + std::vector B(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + for(int i = 0; i < M * K; i++) + { + A[i] = ADataType(float((i % 10) + 1) * 0.1f); + } + for(int i = 0; i < K * N; i++) + { + B[i] = BDataType(float((i % 7) + 1) * 0.1f); + } + + // CPU reference + cpu_gemm(A, B, C_cpu, M, N, K); + + // GPU + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Compare + float max_diff = 0.0f; + float sum_diff = 0.0f; + int correct = 0; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float diff = std::abs(gpu_val - cpu_val); + + max_diff = std::max(max_diff, diff); + sum_diff += diff; + + // FP16 has limited precision (~3-4 decimal digits) + // For K=128, values can reach ~10-30, so allow 5% relative error + absolute tolerance + float tolerance = std::max(std::abs(cpu_val) * 0.05f, 1.0f); + if(diff < tolerance) + { + correct++; + } + } + + float avg_diff = sum_diff / (M * N); + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Max diff: " << max_diff << "\n"; + std::cout << " Avg diff: " << avg_diff << "\n"; + std::cout << " Sample CPU C[0]: " << float(C_cpu[0]) << "\n"; + std::cout << " Sample GPU C[0]: " << float(C_gpu[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + // FP16 accumulation can have significant rounding differences from CPU FP32 + // 90% is reasonable for FP16 with K=128 accumulation + if(accuracy < 90.0f) + { + std::cerr << " FAILED: Too many mismatches vs CPU\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 5: Different Problem Sizes +// ============================================================================= +int test_multiple_sizes() +{ + std::cout << "\n=== Test: Multiple Problem Sizes ===\n"; + + std::vector> sizes = { + {128, 128, 128}, + {256, 256, 256}, + {512, 512, 512}, + {128, 256, 512}, + {512, 256, 128}, + {1024, 1024, 256}, + }; + + int passed = 0; + int total = sizes.size(); + + for(const auto& [M, N, K] : sizes) + { + std::cout << " Testing " << M << "x" << N << "x" << K << "... "; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + hipMalloc(&A_dev, M * K * sizeof(ADataType)); + hipMalloc(&B_dev, K * N * sizeof(BDataType)); + hipMalloc(&C_dev, M * N * sizeof(CDataType)); + + hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice); + hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice); + hipMemset(C_dev, 0, M * N * sizeof(CDataType)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + hipFree(A_dev); + hipFree(B_dev); + hipFree(C_dev); + + // Check result + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + if(accuracy > 99.0f && time > 0) + { + std::cout << "PASS (" << time << " ms)\n"; + passed++; + } + else + { + std::cout << "FAIL (acc=" << accuracy << "%, time=" << time << ")\n"; + } + } + + std::cout << "\n Passed: " << passed << "/" << total << "\n"; + + if(passed < total) + { + std::cerr << " FAILED: Some sizes failed\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 6: Memory Bounds Check +// ============================================================================= +int test_memory_bounds() +{ + std::cout << "\n=== Test: Memory Bounds Check ===\n"; + + const int M = 256, N = 256, K = 256; + const float sentinel = -999.0f; + + // Allocate with extra padding and sentinel values + const int padding = 16; + std::vector A(M * K + padding, ADataType(1.0f)); + std::vector B(K * N + padding, BDataType(1.0f)); + std::vector C(M * N + padding, CDataType(sentinel)); + + // Set sentinels at the end + for(int i = 0; i < padding; i++) + { + A[M * K + i] = ADataType(sentinel); + B[K * N + i] = BDataType(sentinel); + } + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, (M * K + padding) * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, (K * N + padding) * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, (M * N + padding) * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A.data(), (M * K + padding) * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B.data(), (K * N + padding) * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(C_dev, C.data(), (M * N + padding) * sizeof(CDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK( + hipMemcpy(C.data(), C_dev, (M * N + padding) * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Check sentinels weren't overwritten + bool sentinels_intact = true; + for(int i = 0; i < padding; i++) + { + if(float(C[M * N + i]) != sentinel) + { + sentinels_intact = false; + std::cerr << " Sentinel overwritten at position " << (M * N + i) << "\n"; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(!sentinels_intact) + { + std::cerr << " FAILED: Memory bounds violated\n"; + return 1; + } + + // Also check actual results are correct + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Sentinels intact: Yes\n"; + std::cout << " Result accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Results incorrect\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Main +// ============================================================================= +int main() +{ + std::cout << "========================================\n"; + std::cout << "CK Tile Sanity Check Tests\n"; + std::cout << "========================================\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + + // Setup + setup_dispatcher(); + + int failures = 0; + + // Run all tests + failures += test_all_ones(); + failures += test_non_zero_results(); + failures += test_performance(); + failures += test_vs_cpu_reference(); + failures += test_multiple_sizes(); + failures += test_memory_bounds(); + + std::cout << "\n========================================\n"; + if(failures == 0) + { + std::cout << "ALL TESTS PASSED\n"; + std::cout << "CK Tile is running correctly on GPU.\n"; + return 0; + } + else + { + std::cout << failures << " TEST(S) FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_tile_backend.cpp b/dispatcher/tests/test_tile_backend.cpp new file mode 100644 index 0000000000..4e7c693071 --- /dev/null +++ b/dispatcher/tests/test_tile_backend.cpp @@ -0,0 +1,155 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for CK Tile backend using Google Test +/// Note: This test validates the dispatcher wrapper infrastructure, not actual kernel execution + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +namespace { + +// Note: Actual CK Tile backend tests require real generated kernels and GPU hardware. +// These tests verify the dispatcher's tile backend interface and wrapper functionality +// using mock kernels instead of real tile kernels. +} // anonymous namespace + +// These tests verify the tile backend can be used with mock kernels +// Real tile kernel integration would require generated CK Tile kernels + +TEST(TileBackendTest, KernelKeyCreation) +{ + // Test creating a kernel key for tile backend + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.algorithm.tile_shape.n, 256); + EXPECT_EQ(key.algorithm.tile_shape.k, 32); + EXPECT_EQ(key.gfx_arch, "gfx942"); + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); +} + +TEST(TileBackendTest, MockKernelRegistration) +{ + // Clear registry for clean test + Registry::instance().clear(); + + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + + // Register kernel + bool registered = Registry::instance().register_kernel(kernel); + EXPECT_TRUE(registered); + + // Lookup kernel + std::string kernel_id = key.encode_identifier(); + auto found_kernel = Registry::instance().lookup(kernel_id); + EXPECT_NE(found_kernel, nullptr); + EXPECT_EQ(found_kernel->get_name(), "mock_tile_kernel"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, DispatcherWithMockTileKernel) +{ + // Clear registry + Registry::instance().clear(); + + // Create and register mock tile kernel + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + Registry::instance().register_kernel(kernel); + + // Create dispatcher + Dispatcher dispatcher; + + // Test kernel selection - divisible dimensions + Problem problem1(512, 512, 512); // Divisible by 256, 256, 32 + auto selected1 = dispatcher.select_kernel(problem1); + EXPECT_NE(selected1, nullptr); + EXPECT_EQ(selected1->get_name(), "mock_tile_kernel"); + + // Test with non-divisible problem + Problem problem2(100, 200, 300); // Not divisible + auto not_selected = dispatcher.select_kernel(problem2); + EXPECT_EQ(not_selected, nullptr); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileKernelIdentifierEncoding) +{ + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + std::string id = key.encode_identifier(); + + // Should contain tile dimensions + EXPECT_NE(id.find("256x256x32"), std::string::npos); + EXPECT_NE(id.find("2x2x1"), std::string::npos); + EXPECT_NE(id.find("32x32x16"), std::string::npos); + + // Should contain persistent flag + EXPECT_NE(id.find("nopers"), std::string::npos); // persistent = false +} + +TEST(TileBackendTest, MultipleKernelRegistration) +{ + // Clear registry + Registry::instance().clear(); + + // Register multiple kernels with different tile sizes + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + auto kernel1 = std::make_shared(key1, "kernel_256x256x32", false); + + KernelKey key2 = make_test_key(128, 128, 64, "gfx942"); + auto kernel2 = std::make_shared(key2, "kernel_128x128x64", false); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + EXPECT_EQ(Registry::instance().size(), 2); + + // Verify both are accessible + auto found1 = Registry::instance().lookup(key1.encode_identifier()); + auto found2 = Registry::instance().lookup(key2.encode_identifier()); + + EXPECT_NE(found1, nullptr); + EXPECT_NE(found2, nullptr); + EXPECT_EQ(found1->get_name(), "kernel_256x256x32"); + EXPECT_EQ(found2->get_name(), "kernel_128x128x64"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileSizeSupport) +{ + Registry::instance().clear(); + + // Create kernel with 256x256x32 tiles (no padding) + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "test_kernel", false); // strict divisibility + + // Should support 512x512x512 (divisible) + EXPECT_TRUE(kernel->supports(Problem(512, 512, 512))); + + // Should support 256x256x32 (exact match) + EXPECT_TRUE(kernel->supports(Problem(256, 256, 32))); + + // Should NOT support 100x200x300 (not divisible) + EXPECT_FALSE(kernel->supports(Problem(100, 200, 300))); + + // Should support 1024x1024x1024 (divisible) + EXPECT_TRUE(kernel->supports(Problem(1024, 1024, 1024))); + + Registry::instance().clear(); +} From 31a35ecab4e403f63ec4b76f4a709c21172c39de Mon Sep 17 00:00:00 2001 From: kensclin Date: Fri, 23 Jan 2026 01:39:38 +0800 Subject: [PATCH 09/42] GEMM Blockscale ABQuant Optimization (#3620) * GEMM Blockscale ABQuant Optimization * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix precommit error * clean * Fix --------- Co-authored-by: Thomas Ning Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Ding, Yi --- .../gemm_abquant_quantgrouped.cpp | 16 ++-- .../38_block_scale_gemm/gemm_utils.hpp | 29 +++++++ .../run_gemm_quant_example.inc | 4 +- include/ck_tile/core/tensor/sweep_tile.hpp | 12 +-- ...versal_gemm_ar_aquant_flatbr_bquant_cr.hpp | 44 +++++++++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 81 +++++++++++++------ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 26 ++++-- 7 files changed, 161 insertions(+), 51 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 155f19881e..b1cd1a52a7 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -4,7 +4,13 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigABQuantPrefill; + +template +using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; + +// template +// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; void abquant_quantgrouped_instance_factory( std::unordered_map>& lut) @@ -78,7 +84,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -93,7 +99,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -108,7 +114,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, @@ -123,7 +129,7 @@ void abquant_quantgrouped_instance_factory( using BQuantGroupSize = ck_tile::QuantGroupShape>; using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, + return run_gemm_example_prec_type, TypeConfig, AQuantGroupSize, BQuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 37fc998e5b..a95ca4862c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -192,6 +192,28 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill static constexpr bool PreshuffleQuant = true; }; +template +struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQuant_Prefill +{ + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + +template +struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + template struct GemmConfigQuantPrefill : public GemmConfigBase { @@ -209,6 +231,13 @@ struct GemmConfigQuantPrefill : public GemmConfigBase ck_tile::get_k_warp_tile(); }; +template +struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill +{ + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + template struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 607c53d9af..912527c929 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -33,6 +33,7 @@ template ); + constexpr bool transpose_c = QuantMode == ck_tile::QuantType::ABQuantGrouped; using ComputeDataType = std::conditional_t; using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase > CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) { - using DstrSpan = remove_cvref_t; + using DstrSpanImpl = typename remove_cvref_t::Impl; - static_ford{}([&](auto dstr_idx_impl) { - constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl); - - f(dstr_idx); - }); + if constexpr(DstrSpanImpl::size() == 0) // handle the 0-dim span case + f(detail::make_tile_distributed_index(sequence<>{})); + else + static_ford{}( + [&](auto dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)); }); } // unpacked span, this version support span with unpack(multi-arg) functor diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 63a5151108..b4a1bf886e 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -213,6 +213,22 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg }); }); }; + + auto q_block_tensor = aq_block_tensor; + constexpr bool SimpleDequant = + Traits::NQPerBlock == 1 && + AccTensor::get_distributed_spans()[I0].impl_.size() == 0; // c_transpose + if constexpr(SimpleDequant) + { + constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); + sweep_tile_span(aq_spans[I0], [&](auto im) { + sweep_tile_span(aq_spans[I1], [&](auto ik) { + q_block_tensor(make_tuple(im, ik)) *= + bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); + }); + }); + } + // hot loop: static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { zero_accumulators(); static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) { @@ -243,9 +259,29 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg } }); }); - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - AQPickerCommon aq_picker(aq_block_tensor); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for_product, number>{}([&](auto mIter, + auto nIter) { + if constexpr(SimpleDequant) + { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + constexpr auto block_idx_m = tile_distributed_index{}; + constexpr auto block_idx_kq = tile_distributed_index{}; + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref += acc_val * q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); + }); + } + else + { + AQPickerCommon aq_picker( + aq_block_tensor); constexpr auto tbuf_offset = number{}, @@ -273,7 +309,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; }); - }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d13..3fb80c21ff 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -285,37 +285,66 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase "C block tensor data type!"); constexpr auto warp_size = get_warp_size(); + // Start from AQ block tensor and then scale it using BQ; this represents + // the combined A/B quantization scales for the block. + auto q_block_tensor = aq_block_tensor; + constexpr bool SimpleDequant = + Traits::NQPerBlock == 1 && + CWarpTensor::get_distributed_spans()[I0{}].impl_.size() == 0; // c_transpose + if constexpr(SimpleDequant) + { + constexpr auto aq_spans = AQBlockTensor::get_distributed_spans(); + sweep_tile_span(aq_spans[I0{}], [&](auto im) { + sweep_tile_span(aq_spans[I1{}], [&](auto ik) { + q_block_tensor(make_tuple(im, ik)) *= + bq_block_tensor(make_tuple(tile_distributed_index<0>{}, ik)); + }); + }); + } + // hot loop: - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for_product, number>{}([&](auto mIter, + auto nIter) { CWarpTensor c_warp_tensor; + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; - static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { - static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { - constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = - a_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + }); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = - b_warp_tile_.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - if constexpr(kIterInQScale == 0) - { - c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); - } - else - { - WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } + if constexpr(SimpleDequant) + { + constexpr auto cw_spans = CWarpTensor::get_distributed_spans(); + sweep_tile_span(cw_spans[I1{}], [&](auto in) { + constexpr auto block_idx_m = tile_distributed_index{}; + constexpr auto block_idx_n = detail::make_tile_distributed_index( + merge_sequences(sequence{}, in.impl_)); + constexpr auto block_idx_kq = tile_distributed_index{}; + constexpr auto empty_idx = tile_distributed_index<>{}; + c_block_tensor(make_tuple(block_idx_m, block_idx_n)) += + c_warp_tensor(make_tuple(empty_idx, in)) * + q_block_tensor(make_tuple(block_idx_m, block_idx_kq)); }); - + } + else + { constexpr auto tbuf_offset = number{}, @@ -387,7 +416,7 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase b_scale_reg_f); }); } - }); + } }); }); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 0f3951ffcc..566f0b6153 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -101,10 +101,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName()); // clang-format on } - + /** + * @tparam nloop The number of iterations in the hot loop, + * used to normalize scheduling costs. + */ template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + static_assert(nloop > 0, "nloop must be greater than 0"); // Estimated number of VMEM vector loads for A per block: // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = @@ -127,12 +131,13 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; // Approximate number of LDS reads per block - constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle / nloop; // Approximate number of LDS writes per block // (e.g., writing A from VMEM into LDS once per A load) constexpr index_t ds_write_inst = Aload_inst; // Number of MFMA instructions per wave for one block tile: - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + constexpr index_t mfma_inst = + ((kMPerBlock / WG::kM) / nloop) * ((kNPerBlock / WG::kN) / nloop); // How often (in MFMA units) we should insert DS (LDS) operations. constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); // How often (in MFMA units) we should insert VMEM buffer loads. @@ -169,7 +174,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe } // Always mark some VALU work in the loop to reflect auxiliary scalar // or vector ALU instructions that coexist with MFMA (Blockscale calculation). - __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 4, 0); // VALU }); }); __builtin_amdgcn_sched_barrier(0); @@ -380,7 +385,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe // Prefetch A1 a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // initialize C @@ -407,7 +411,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe while(iCounter > 0) { __builtin_amdgcn_sched_barrier(0); - // Prefill A(2i+1) + // Prefill A(2i+1) ds_write a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_pong, a_block_tile_tmp); @@ -435,10 +439,14 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // prefetch Q(2i+1) aq_block_tile_2 = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile_2 = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + + // Preload A(2i+1) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -460,6 +468,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // prefetch Q(2i+1) aq_block_tile = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile = load_tile(bq_copy_dram_window); @@ -481,7 +491,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile_2, bq_block_tile_2, a_warp_windows_pong); - + // Preload A(2i+2) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; @@ -521,7 +531,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile, bq_block_tile, a_warp_windows_ping); - + // Preload A ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; From b9bb1db5d932c4c0445994cfc1d37f66a3744659 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Thu, 22 Jan 2026 12:53:52 -0800 Subject: [PATCH 10/42] Addition of Stream-K tests using Tile Engine (#3514) * Addition of Stream-K tests using Tile Engine This change adds an implementation for generating Stream-K tests using Tile Engine. This will generate various test executables for different combinations based on the config files. This addition has simple tests running for bf16 and fp16, with both atomic and reduction strategies and compv3 pipeline. The tests rely on the implementation of Stream-K in Tile Engine. * integrating addition of tree reduction and editing the README * temporarily removing parallel and tree reduction from configs while bugs regarding them are being resolved --- test/ck_tile/CMakeLists.txt | 1 + .../gemm_streamk_tile_engine/CMakeLists.txt | 306 ++++++++++++++++++ .../gemm_streamk_tile_engine/README.md | 56 ++++ .../configs/simple_test_config.json | 35 ++ .../extract_test_params.py | 74 +++++ .../test_gemm_streamk_simple.cpp | 240 ++++++++++++++ .../gemm_streamk/configs/default_config.json | 2 +- .../gemm_streamk_instance_builder.py | 6 + .../gemm_streamk/gemm_streamk_profiler.hpp | 5 +- 9 files changed, 723 insertions(+), 2 deletions(-) create mode 100644 test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt create mode 100644 test/ck_tile/gemm_streamk_tile_engine/README.md create mode 100644 test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json create mode 100644 test/ck_tile/gemm_streamk_tile_engine/extract_test_params.py create mode 100644 test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 70649ed8f8..d932411991 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -41,3 +41,4 @@ add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) add_subdirectory(grouped_conv) +add_subdirectory(gemm_streamk_tile_engine) diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt new file mode 100644 index 0000000000..664866d458 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -0,0 +1,306 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================ +# GEMM Tile Engine Unit Tests +# +# This CMake file creates unit tests for tile_engine generated GEMM kernels. +# It follows the exact same build patterns as tile_engine for consistency +# and reliability. Each kernel configuration gets its own test executable. +# ============================================================================ + +# Locate tile_engine GEMM scripts directory +set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm_streamk") + +if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR}) + message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}") + return() +endif() + +# ============================================================================ +# create_individual_gemm_test_target +# +# Creates a single test executable for a specific kernel configuration. +# Mirrors tile_engine's create_individual_gemm_target function for consistency. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) +# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# trait - Kernel trait combination string +# tile_config - Tile configuration parameters +# config_json - Full path to JSON configuration file +# ============================================================================ +function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json) + set(target_name "test_gemm_streamk_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Generated header path (already created during cmake configuration) + set(test_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + set(test_params_header "${working_path}/test_params.hpp") + + # Verify header exists (should have been generated during cmake configuration) + if(NOT EXISTS ${test_header}) + message(WARNING "Generated header not found: ${test_header}") + return() + endif() + + # Verify test parameters header exists + if(NOT EXISTS ${test_params_header}) + message(WARNING "Test parameters header not found: ${test_params_header}") + return() + endif() + + + # Create GTest executable for this kernel configuration + add_gtest_executable(${target_name} + ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_streamk_simple.cpp + ) + + # Configure GPU architectures for HIP compilation + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) + + # Define preprocessor macros for generated header location and test parameters + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${test_header}" + GEMM_TEST_PARAMS_HPP="${test_params_header}" + ) + + # Include directories for headers and dependencies + target_include_directories(${target_name} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR} # Root directory for tile_engine access + ${GTEST_INCLUDE_DIRS} + ) + + # Compiler options matching tile_engine requirements + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template # Suppress template warnings + -Wno-float-equal # Allow floating point comparisons + --offload-compress # Enable GPU code compression + -include ${test_header} # Auto-include generated header + ) + + # Add FP8 format definitions for proper data type interpretation + if(CK_USE_OCP_FP8) + target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() + + message(STATUS " Created test target: ${target_name}") +endfunction() + +# ============================================================================ +# build_gemm_test_targets +# +# Builds all test targets for a specific datatype/layout/config combination. +# Uses tile_engine's two-step process: list kernels, then generate tests. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) +# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# ============================================================================ +function(build_gemm_test_targets datatype layout config_name) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Locate and validate configuration file + set(config_filename "${config_name}.json") + set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") + + if(NOT EXISTS ${json_blob}) + message(WARNING "Test config file not found: ${json_blob}") + return() + endif() + + # Prepare build directory for this configuration + file(MAKE_DIRECTORY ${working_path}) + + # STEP 1: Discovery phase - list all valid kernel configurations + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}") + return() + endif() + + # Verify kernel list file was generated + if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) + message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") + return() + endif() + + message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + + # STEP 2a: Extract test parameters from config + set(test_params_file "${working_path}/test_params.hpp") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py + --config_file ${json_blob} + --output_file ${test_params_file} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE extract_ret + OUTPUT_VARIABLE extract_output + ERROR_VARIABLE extract_error + ) + + if(NOT extract_ret EQUAL 0) + message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}") + return() + endif() + + # STEP 2b: Header generation phase - generate headers using --gen_single + message(STATUS " Generating headers using --gen_single...") + + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(gen_count 0) + + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate header using --gen_single + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gen_single + --kernel_name "${kernel_name}" + --tile_config "${tile_config}" + --trait_combo "${trait_combo}" + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE gen_ret + OUTPUT_VARIABLE gen_output + ERROR_VARIABLE gen_error + ) + + if(NOT gen_ret EQUAL 0) + message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}") + else() + math(EXPR gen_count "${gen_count} + 1") + endif() + endif() + endforeach() + + message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}") + + # STEP 3: Target creation phase - create test targets + message(STATUS " Creating test targets...") + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(test_count 0) + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate test target for this kernel configuration + create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") + math(EXPR test_count "${test_count} + 1") + endif() + endforeach() + message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") +endfunction()# ============================================================================ +# MAIN EXECUTION - Test Target Generation +# ============================================================================ + +message(STATUS "=== Starting StreamK GEMM Tile Engine Test Configuration ===") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# GPU architecture filtering - only build tests for supported architectures +set(GEMM_TEST_GPU_TARGETS "") +set(DESIRED_TARGETS "gfx90a;gfx942") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_TEST_GPU_TARGETS ${target}) + message(STATUS " Adding GPU target for tests: ${target}") + endif() +endforeach() + +# Early exit if no compatible GPU architectures are available +if(NOT GEMM_TEST_GPU_TARGETS) + message(WARNING "Skipping StreamK GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +message(STATUS "Building StreamK GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF) + if(ENABLE_CCACHE_TESTS) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster test compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)") + endif() + +# ============================================================================ +# Test Configuration Matrix - Clean Focused Design +# ============================================================================ + +# All supported data types and layouts for comprehensive testing +# Note: fp64 not included (no MFMA hardware support) +set(TEST_DATATYPES "fp16;bf16") +set(TEST_LAYOUTS "rcr;rrr;ccr;crr") + +# ============================================================================ +# Test Target Generation - Datatype-Specific Categories +# ============================================================================ + +# 1. SIMPLE TEST: Test for basic functionality with data types (fp16, bf16) +# These data types can use larger warp tiles due to smaller memory footprint +set(SIMPLE_TEST_CONFIG "simple_test_config") +set(SIMPLE_TEST_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SIMPLE_TEST_CONFIG}.json") +set(SIMPLE_DATATYPES "fp16;bf16") + +if(EXISTS ${SIMPLE_TEST_CONFIG_FILE}) + message(STATUS "Processing simple test config: ${SIMPLE_TEST_CONFIG} (fp16, bf16)") + foreach(datatype IN LISTS SIMPLE_DATATYPES) + # fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) + foreach(layout IN LISTS TEST_LAYOUTS) + build_gemm_test_targets("${datatype}" "${layout}" "${SIMPLE_TEST_CONFIG}") + endforeach() + endforeach() +else() + message(WARNING "Simple test config file not found: ${SIMPLE_TEST_CONFIG_FILE}") +endif() +# ============================================================================ + + +message(STATUS "StreamK GEMM tile engine tests configured with datatype-specific design:") +message(STATUS " - Simple test: fp16/bf16 (all layouts)") diff --git a/test/ck_tile/gemm_streamk_tile_engine/README.md b/test/ck_tile/gemm_streamk_tile_engine/README.md new file mode 100644 index 0000000000..4655673852 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/README.md @@ -0,0 +1,56 @@ +# Stream-K GEMM Tile Engine Unit Tests + +## How It Works + +This unit test system integrates **tile_engine's kernel generation** into automated testing: + +1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels +2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine) +3. **Build-time generation**: CMake calls tile_engine scripts to generate kernel headers +4. **Individual test executables**: Each kernel configuration becomes a separate test +5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine + +## Tile Engine Integration + +``` +JSON Config → tile_engine Python scripts → Generated Headers → Test Executables +``` + +- **`--list_kernels`**: Get available kernel configurations from JSON +- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration +- **`--gen_single`**: Generate individual kernel header for each configuration +- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations +- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching + +### Config-Specific Test Parameters + +Each test configuration can specify optimized problem sizes in its JSON file: +- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations +- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files +- **Build integration**: Each test target uses parameters appropriate for its kernel configuration +- **Optimized testing**: Different configs test different problem sizes that showcase their strengths + + +The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. + +## Test Configurations + +### 1. **Simple Test** (`simple_test_config.json`) +- **Purpose**: Basic functionality validation for fp16/bf16 data types +- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +- **Traits**: compv3 pipeline only +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp16, bf16 + +## Data Type Support +- ✅ **fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ❌ **fp64**: Not supported (hardware MFMA limitation) +- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later) + +## Test Result Behavior + +Tests automatically handle unsupported configurations through runtime validation: +- **PASSED**: Kernel executed correctly with results within error thresholds ✅ +- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️ +- **FAILED**: Actual error or incorrect computation results ❌ + +When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements. diff --git a/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json new file mode 100644 index 0000000000..1cfeef7570 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/configs/simple_test_config.json @@ -0,0 +1,35 @@ +{ + "problem": { + "description": "Basic functionality validation with moderate problem sizes" + }, + "test_params": { + "problem_sizes": [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 256, "split_k": 1}, + {"m": 256, "n": 512, "k": 256, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [64]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false, true]}, + "reduction_strategy": {"values": ["atomic"]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_streamk_tile_engine/extract_test_params.py b/test/ck_tile/gemm_streamk_tile_engine/extract_test_params.py new file mode 100644 index 0000000000..48ec8dba83 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/extract_test_params.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + + +import json +import argparse +import os +from pathlib import Path + + +def extract_test_params(config_file, output_file): + """Extract test parameters from config JSON and write to output file""" + + # Read config file + with open(config_file, "r") as f: + config = json.load(f) + + # Extract test parameters + test_params = [] + if "test_params" in config and "problem_sizes" in config["test_params"]: + test_params = config["test_params"]["problem_sizes"] + else: + # Default test parameters if none specified + test_params = [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 256, "n": 256, "k": 1024, "split_k": 1}, + {"m": 256, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 256, "k": 512, "split_k": 1}, + ] + + # Write to output file in C++ format + output_dir = Path(output_file).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_file, "w") as f: + f.write("// Generated test parameters for this configuration\n") + f.write("// This file is auto-generated during CMake configuration\n\n") + f.write("static const std::vector CONFIG_TEST_PARAMS = {\n") + + for i, params in enumerate(test_params): + comma = "," if i < len(test_params) - 1 else "" + f.write( + f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n" + ) + + f.write("};\n") + + print( + f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Extract test parameters from config JSON" + ) + parser.add_argument("--config_file", required=True, help="Input config JSON file") + parser.add_argument( + "--output_file", required=True, help="Output test parameters file" + ) + + args = parser.parse_args() + + if not os.path.exists(args.config_file): + print(f"Error: Config file not found: {args.config_file}") + return 1 + + extract_test_params(args.config_file, args.output_file) + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp new file mode 100644 index 0000000000..913e7d8531 --- /dev/null +++ b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file test_gemm_simple.cpp + * @brief Unit tests for GEMM kernels generated by gemm_instance_builder + * + * This test includes kernels generated during CMake configuration by + * gemm_instance_builder.py and tests them with problem sizes extracted + * from the corresponding JSON configuration files. + */ + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp" + +// The kernel header is included via compile command line with -include flag +// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types + +// Adaptive error threshold calculation matching tile_engine's implementation +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations (from tile_engine) +template +bool compare_results(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} + +// Test parameter structure for matrix dimensions and split_k values +struct GemmTestParams +{ + int m, n, k, split_k; +}; + +// Include config-specific test parameters (after GemmTestParams struct is defined) +#ifdef GEMM_TEST_PARAMS_HPP +#include GEMM_TEST_PARAMS_HPP +#endif + +class StreamKGemmTileEngineTest : public ::testing::TestWithParam +{ + protected: + void SetUp() override + { + auto params = GetParam(); + m_ = params.m; + n_ = params.n; + k_ = params.k; + split_k_ = params.split_k; + + // Calculate strides (following tile_engine pattern) + if constexpr(std::is_same_v) + { + stride_a_ = k_; + } + else + { + stride_a_ = m_; + } + + if constexpr(std::is_same_v) + { + stride_b_ = n_; + } + else + { + stride_b_ = k_; + } + + if constexpr(std::is_same_v) + { + stride_c_ = n_; + } + else + { + stride_c_ = m_; + } + } + + // Test dimensions + int m_, n_, k_, split_k_; + int stride_a_, stride_b_, stride_c_; +}; + +TEST_P(StreamKGemmTileEngineTest, BasicFunctionality) +{ + // Get tensor layouts from generated kernel + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + // Use split_k from test parameters + int split_k = split_k_; + int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); + int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); + int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); + + // Create host tensors with proper descriptors + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + + // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + + // Allocate GPU device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + // Copy data to device and zero output buffer + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + // Calculate reference result on host for verification + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + + // Create GEMM kernel arguments + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + m_, + n_, + k_, + stride_a_calc, + stride_b_calc, + stride_c_calc}; + + // Configure kernel execution for maximum speed (no timing, no debug output) + ck_tile::stream_config stream_config{nullptr, // stream + false, // time_kernel (disable timing for speed) + 0, // log_level (disable debug output) + 0, // n_warmup + 1, // n_repeat + false, // is_gpu_timer (unused when time_kernel=false) + false, // flush_cache + 1}; // rotating_count + + // Launch the generated kernel (no timing overhead for fastest execution) + try + { + SelectedKernel::launch(args, stream_config); + // Kernel launched successfully if no exception thrown + } + catch(const std::exception& e) + { + std::string error_msg(e.what()); + // If arguments not supported, skip the test (configuration validation failure, not a bug) + if(error_msg.find("Arguments not supported") != std::string::npos) + { + GTEST_SKIP() << "Configuration not supported: " << e.what(); + } + else + { + FAIL() << "Kernel launch failed: " << e.what(); + } + } + + // Copy result back from device + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + // Verify results using tile_engine's adaptive error thresholds + bool verification_passed = compare_results( + KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); + + EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; +} + +TEST_P(StreamKGemmTileEngineTest, KernelInfo) +{ + // Simple test to verify kernel information is available + EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; + + std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; + std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ + << std::endl; +} + +// Use config-specific test parameters (included via compile flags) +// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file +INSTANTIATE_TEST_SUITE_P(GemmVerification, + StreamKGemmTileEngineTest, + ::testing::ValuesIn(CONFIG_TEST_PARAMS), + [](const ::testing::TestParamInfo& param_info) { + return std::to_string(param_info.param.m) + "x" + + std::to_string(param_info.param.n) + "x" + + std::to_string(param_info.param.k) + "_splitk" + + std::to_string(param_info.param.split_k); + }); diff --git a/tile_engine/ops/gemm_streamk/configs/default_config.json b/tile_engine/ops/gemm_streamk/configs/default_config.json index f6b92feee3..07281bdf9a 100644 --- a/tile_engine/ops/gemm_streamk/configs/default_config.json +++ b/tile_engine/ops/gemm_streamk/configs/default_config.json @@ -98,7 +98,7 @@ }, "reduction_strategy": { "values": [ - "reduction", "atomic" + "atomic" ] } } diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index d7aaa6121a..877c803d69 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -377,6 +377,7 @@ class GemmKernelBuilder: reduction_strategy_map = { "atomic": "ck_tile::StreamKReductionStrategy::Atomic", "reduction": "ck_tile::StreamKReductionStrategy::Reduction", + "tree": "ck_tile::StreamKReductionStrategy::TreeReduction", } # Determine accumulator type based on datatype @@ -555,6 +556,11 @@ struct SelectedKernel {{ // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); }} + else if(reduction_strategy == ck_tile::StreamKReductionStrategy::TreeReduction) + {{ + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); + }} }}; // Launch kernel diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index 0541116522..d168030f97 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -165,10 +165,13 @@ class GemmProfiler auto [name, avg_time] = kernel_run_result; auto dp_persistent = SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel"; + auto reduction_strategy = SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic ? "Atomic" - : "Reduction"; + : SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction + ? "Reduction" + : "TreeReduction"; KernelInstance kernel_instance{ name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}}; From eb2dc8f466cd2978490ccc3ff794d898cad9535a Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Thu, 22 Jan 2026 14:44:47 -0800 Subject: [PATCH 11/42] Speed up glob recurse. (#3626) --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 54464d6809..9f1bdf8689 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -654,7 +654,9 @@ endif() -file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") +# Optimization: Search only in library/src where all instance files actually live +# (was searching entire source tree, taking ~40s instead of <1s) +file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) From f30d04654e6bb9b064cf96c6bb4e3fff960efbd8 Mon Sep 17 00:00:00 2001 From: damien-lejeune <31985270+damien-lejeune@users.noreply.github.com> Date: Fri, 23 Jan 2026 01:06:02 +0100 Subject: [PATCH 12/42] Add missing check target in reduce tile engine op (#3631) Co-authored-by: Damien Lejeune --- tile_engine/ops/reduce/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tile_engine/ops/reduce/CMakeLists.txt b/tile_engine/ops/reduce/CMakeLists.txt index 4d5297b187..fa62890a5c 100644 --- a/tile_engine/ops/reduce/CMakeLists.txt +++ b/tile_engine/ops/reduce/CMakeLists.txt @@ -96,6 +96,7 @@ function(build_multi_reduce_for_datatype datatype variant) add_test(NAME ${test_target} COMMAND ${test_target}) set_tests_properties(${test_target} PROPERTIES LABELS "multi_reduce") + add_dependencies(check ${test_target}) endforeach() add_custom_target(test_reduce_${variant}_${datatype} DEPENDS ${codegen_blobs}) @@ -123,4 +124,4 @@ foreach(dt IN LISTS MULTI_REDUCE_DATATYPE) foreach(l IN LISTS MULTI_REDUCE_VARIANTS) build_multi_reduce_for_datatype(${dt} ${l}) endforeach() -endforeach() \ No newline at end of file +endforeach() From de5a1d730dc77d1471ad53ca18dfd7c1474e9873 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 23 Jan 2026 13:21:19 +0800 Subject: [PATCH 13/42] Revert " Fp8 block scale quantization for fmha fwd (#3330)" (#3633) This reverts commit dd0b4294afcf188f4a9154b7eea19f8e786c9539. --- CHANGELOG.md | 1 - .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 - .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 7 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 26 -- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 230 ++++------------- example/ck_tile/01_fmha/quant.hpp | 7 - .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 5 +- include/ck_tile/core/numeric/math.hpp | 7 - include/ck_tile/core/utility/functional.hpp | 12 - .../host/reference/reference_batched_gemm.hpp | 40 --- .../block_attention_quant_scale_enum.hpp | 6 - .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 242 +----------------- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 83 +----- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 83 +----- 14 files changed, 84 insertions(+), 667 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768..54c8b776dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,6 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. -* Added FP8 block scale quantization for FMHA forward kernel. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index cac6671ca5..a3cfe2622a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -77,13 +77,11 @@ def get_mask_cpp_check_expr(mask: str) -> str: QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", - "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", - "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index b59f442663..81c7b067d3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1024,7 +1024,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( ["t", "f"], - ["no", "pertensor", "blockscale"], + ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"], ["f", "t"], @@ -1152,10 +1152,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], - ["no", "pertensor", "blockscale"], - get_mask_map(mask_impl).keys(), - ["no"], + ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index aedbb0e17c..fdd720fd75 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,8 +230,6 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) - const void* block_scale_seqstart_q_ptr; - const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -259,9 +257,6 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_q_descale; - ck_tile::index_t nhead_stride_k_descale; - ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -269,9 +264,6 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; - ck_tile::index_t batch_stride_q_descale; - ck_tile::index_t batch_stride_k_descale; - ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -284,9 +276,6 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; - - ck_tile::index_t block_scale_size_q; - ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -626,8 +615,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, - args.block_scale_seqstart_q_ptr, - args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -647,9 +634,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, - args.nhead_stride_q_descale, - args.nhead_stride_k_descale, - args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -658,8 +642,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_size_q, - args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); @@ -697,9 +679,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, - args.nhead_stride_q_descale, - args.nhead_stride_k_descale, - args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -707,9 +686,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, - args.batch_stride_q_descale, - args.batch_stride_k_descale, - args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -717,8 +693,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.block_scale_size_q, - args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index b6287245a0..0c988b2acc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,11 +210,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { - // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the - // compute block size - constexpr ck_tile::index_t block_scale_size_q_ = 128; - constexpr ck_tile::index_t block_scale_size_kv_ = 128; - const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -476,11 +471,7 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t i_block_scale_q = 0; - size_t i_block_scale_k = 0; - std::vector block_scale_seqstart_q_host = {0}; - std::vector block_scale_seqstart_k_host = {0}; - auto max_seqlen_k = std::numeric_limits::min(); + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -496,10 +487,6 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } - i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); - i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); - block_scale_seqstart_q_host.push_back(i_block_scale_q); - block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -561,15 +548,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); - const ck_tile::index_t num_block_scale_q = - (mode == mode_enum::batch) - ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) - : i_block_scale_q; - const ck_tile::index_t num_block_scale_kv = - (mode == mode_enum::batch) - ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) - : i_block_scale_k; - ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor sink_host({nhead}); @@ -621,18 +599,9 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead, num_block_scale_q} - : std::array{1, 1, 1}); - ck_tile::HostTensor k_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); - ck_tile::HostTensor v_descale_host( - qscale.type == quant_scale_enum::blockscale - ? std::array{shape_batch, nhead_k, num_block_scale_kv} - : std::array{1, 1, 1}); + ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -748,12 +717,6 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } - else if(qscale.type == quant_scale_enum::blockscale) - { - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); - ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); - } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -774,10 +737,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * - sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -823,8 +782,6 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); - block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); - block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -1018,14 +975,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; - const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; - const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -1043,9 +997,6 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; - const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; - const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1133,39 +1084,9 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - if(qscale.type == quant_scale_enum::blockscale) - { - args.q_descale_ptr = - reinterpret_cast(q_descale_buf.GetDeviceBuffer()); - args.k_descale_ptr = - reinterpret_cast(k_descale_buf.GetDeviceBuffer()); - args.v_descale_ptr = - reinterpret_cast(v_descale_buf.GetDeviceBuffer()); - - args.block_scale_seqstart_q_ptr = - (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() - : nullptr); - args.block_scale_seqstart_k_ptr = - (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() - : nullptr); - - args.nhead_stride_q_descale = nhead_stride_q_descale; - args.nhead_stride_k_descale = nhead_stride_k_descale; - args.nhead_stride_v_descale = nhead_stride_v_descale; - - args.batch_stride_q_descale = batch_stride_q_descale; - args.batch_stride_k_descale = batch_stride_k_descale; - args.batch_stride_v_descale = batch_stride_v_descale; - - args.block_scale_size_q = block_scale_size_q_; - args.block_scale_size_kv = block_scale_size_kv_; - } - else - { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); - } + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1668,42 +1589,14 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - if(qscale.type == quant_scale_enum::blockscale) - { - const ck_tile::index_t q_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; - const ck_tile::index_t k_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; - ck_tile::reference_batched_quant_gemm( + ck_tile:: + reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::idx_identity{}, - ck_tile::idx_identity{}, - [&](auto idx, auto value) { - return value * scale_s * - q_descale_host(b_idx, - std::get<0>(idx), - q_offset + std::get<1>(idx) / block_scale_size_q_) * - k_descale_host(b_idx, - std::get<0>(idx) / nr, - k_offset + std::get<2>(idx) / block_scale_size_kv_); - }); - } - else - { - ck_tile:: - reference_batched_gemm( - q_host_ref, - k_host_ref, - s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); - } + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); if(0.f < logits_soft_cap) { @@ -1901,35 +1794,13 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - if(qscale.type == quant_scale_enum::blockscale) - { - const ck_tile::index_t v_offset = - (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; - ck_tile:: - reference_batched_quant_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::idx_identity{}, - [&](auto idx, auto value) { - return ck_tile::type_convert(value) * - v_descale_host(b_idx, - std::get<0>(idx) / nr, - v_offset + - std::get<2>(idx) / block_scale_size_kv_); - }, - ck_tile::idx_identity{}); - } - else - { - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); - } + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1937,6 +1808,7 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on + auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1994,33 +1866,31 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results( - *json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale - ? "no_scale" - : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results(*json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale ? "no_scale" + : "pertensor", + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index feb28cba24..59d4ac1707 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,7 +13,6 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, - blockscale, }; struct quant_scale_info @@ -26,8 +25,6 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; - else if(type == quant_scale_enum::blockscale) - os << "bs"; } static quant_scale_info decode(std::string str) @@ -41,10 +38,6 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } - else if(str == "bs" || str == "2") - { - info.type = quant_scale_enum::blockscale; - } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 227f26c8f3..596542eb9d 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -95,11 +95,10 @@ run_fp8bf16_tests() { for perm in 0 1 ; do for b in 1 2 ; do for hdim in 64 128 256 ; do - for scale in 1 2; do - $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS - done ; done ; done ; done + done ; done ; done } run_fp8fp32_tests() { diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index a46ae509dd..96e76f669d 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -37,13 +37,6 @@ struct scales return lhs_ * rhs; } - template - CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const - { - auto new_scale = lhs_ * other; - return scales>(new_scale); - } - private: Scale lhs_; }; diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index aa4bfa3f15..898d21574e 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -119,18 +119,6 @@ struct identity } }; -// Similar to identity, but takes an additional index parameter as the first argument. -// The index is ignored and only the second argument (value) is forwarded. -// Useful for indexed element-wise operations where the functor signature requires an index. -struct idx_identity -{ - template - CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept - { - return std::forward(arg); - } -}; - namespace detail { // RemainLengths: sequence<...> diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index d742426740..63f13b1b16 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -47,44 +47,4 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } -template -CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, - const HostTensor& b_b_n_k, - HostTensor& c_b_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) -{ - const int N = b_b_n_k.mDesc.get_lengths()[1]; - const int K = b_b_n_k.mDesc.get_lengths()[2]; - - auto f = [&](auto batch, auto m) { - for(int n = 0; n < N; ++n) - { - AccDataType v_acc = 0; - - for(int k = 0; k < K; ++k) - { - AccDataType v_a = ck_tile::type_convert( - a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k))); - AccDataType v_b = ck_tile::type_convert( - b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k))); - - v_acc += v_a * v_b; - } - - c_b_m_n(batch, m, n) = ck_tile::type_convert( - acc_element_op(std::make_tuple(batch, m, n), v_acc)); - } - }; - - make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( - std::thread::hardware_concurrency()); -} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 7e0f704bef..3755a2bc71 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,7 +12,6 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, - BLOCKSCALE, }; template @@ -28,10 +27,5 @@ struct BlockAttentionQuantScaleEnumToStr -struct BlockAttentionQuantScaleEnumToStr -{ - static constexpr const char* name = "blockscale"; -}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0039c57cfc..adbedc5259 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -168,29 +168,6 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; - struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs - { - ck_tile::index_t nhead_stride_q_descale; - ck_tile::index_t nhead_stride_k_descale; - ck_tile::index_t nhead_stride_v_descale; - - ck_tile::index_t block_scale_size_q; - ck_tile::index_t block_scale_size_kv; - }; - - struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs - { - ck_tile::index_t batch_stride_q_descale; - ck_tile::index_t batch_stride_k_descale; - ck_tile::index_t batch_stride_v_descale; - }; - - struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs - { - const int32_t* block_scale_seqstart_q_ptr; - const int32_t* block_scale_seqstart_k_ptr; - }; - struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -266,12 +243,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t< - QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, - FmhaFwdCommonQScaleKargs, - std::conditional_t>>, + std::conditional_t>, std::conditional_t>, std::conditional_t> { @@ -295,12 +269,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t< - QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, - FmhaFwdCommonQScaleKargs, - std::conditional_t>>, + std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -357,9 +328,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -367,9 +335,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -378,8 +343,6 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -450,23 +413,6 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - kargs.q_descale_ptr = q_descale_ptr; - kargs.k_descale_ptr = k_descale_ptr; - kargs.v_descale_ptr = v_descale_ptr; - - kargs.nhead_stride_q_descale = nhead_stride_q_descale; - kargs.nhead_stride_k_descale = nhead_stride_k_descale; - kargs.nhead_stride_v_descale = nhead_stride_v_descale; - - kargs.batch_stride_q_descale = batch_stride_q_descale; - kargs.batch_stride_k_descale = batch_stride_k_descale; - kargs.batch_stride_v_descale = batch_stride_v_descale; - - kargs.block_scale_size_q = block_scale_size_q; - kargs.block_scale_size_kv = block_scale_size_kv; - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -532,9 +478,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -542,9 +485,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -552,8 +492,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -590,9 +528,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -600,9 +535,6 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, - batch_stride_q_descale, - batch_stride_k_descale, - batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -610,8 +542,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -651,9 +581,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -661,9 +588,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_q_descale, - ck_tile::index_t batch_stride_k_descale, - ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -671,8 +595,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -709,9 +631,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -719,9 +638,6 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, - batch_stride_q_descale, - batch_stride_k_descale, - batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -729,8 +645,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -752,8 +666,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -773,9 +685,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -785,8 +694,6 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -856,24 +763,6 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - kargs.q_descale_ptr = q_descale_ptr; - kargs.k_descale_ptr = k_descale_ptr; - kargs.v_descale_ptr = v_descale_ptr; - - kargs.nhead_stride_q_descale = nhead_stride_q_descale; - kargs.nhead_stride_k_descale = nhead_stride_k_descale; - kargs.nhead_stride_v_descale = nhead_stride_v_descale; - - kargs.block_scale_size_q = block_scale_size_q; - kargs.block_scale_size_kv = block_scale_size_kv; - - kargs.block_scale_seqstart_q_ptr = - reinterpret_cast(block_scale_seqstart_q_ptr); - kargs.block_scale_seqstart_k_ptr = - reinterpret_cast(block_scale_seqstart_k_ptr); - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -925,8 +814,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -946,9 +833,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -957,8 +841,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -978,8 +860,6 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - block_scale_seqstart_q_ptr, - block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -999,9 +879,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -1010,8 +887,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1034,8 +909,6 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, - const void* block_scale_seqstart_q_ptr, - const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1055,9 +928,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_q_descale, - ck_tile::index_t nhead_stride_k_descale, - ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -1066,8 +936,6 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - ck_tile::index_t block_scale_size_q, - ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -1087,8 +955,6 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, - block_scale_seqstart_q_ptr, - block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -1108,9 +974,6 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, - nhead_stride_q_descale, - nhead_stride_k_descale, - nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -1119,8 +982,6 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - block_scale_size_q, - block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1250,16 +1111,13 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; - long_index_t batch_offset_q_descale = 0; - long_index_t batch_offset_k_descale = 0; - long_index_t batch_offset_v_descale = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s @@ -1295,14 +1153,6 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; - const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; - batch_offset_q_descale = bquery_start; - batch_offset_k_descale = bkey_start; - batch_offset_v_descale = bkey_start; - } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1370,15 +1220,6 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - batch_offset_q_descale = - static_cast(i_batch) * kargs.batch_stride_q_descale; - batch_offset_k_descale = - static_cast(i_batch) * kargs.batch_stride_k_descale; - batch_offset_v_descale = - static_cast(i_batch) * kargs.batch_stride_v_descale; - } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1699,8 +1540,7 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - - auto o_acc_tile = [&, i_nhead_ = i_nhead]() { + auto o_acc_tile = [&]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1741,62 +1581,8 @@ struct FmhaFwdKernel block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_value); } - else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - const float* q_descale_ptr = - reinterpret_cast(kargs.q_descale_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_q_descale + - batch_offset_q_descale; - const float* k_descale_ptr = - reinterpret_cast(kargs.k_descale_ptr) + - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_k_descale + - batch_offset_k_descale; - const float* v_descale_ptr = - reinterpret_cast(kargs.v_descale_ptr) + - static_cast(i_nhead_ / kargs.nhead_ratio_qk) * - kargs.nhead_stride_v_descale + - batch_offset_v_descale; - - size_t idx = i_m0 / kargs.block_scale_size_q; - float q_descale = q_descale_ptr[idx]; - // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 - // Both P and rowsum are scaled by 2^shift, canceling in normalization - // No additional scaling needed in p_compute_element_func or o_acc_element_func - - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - scales(q_descale), // s_acc_element_func - identity{}, // p_compute_element_func - No scaling (done in exp2) - identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout, - k_descale_ptr, - v_descale_ptr, - kargs.block_scale_size_kv, - sink_value); - } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2fbc9fdb54..dcccdf541c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,13 +57,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; - // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] - static constexpr float OCP_FP8_SHIFT = 8.0f; - static constexpr float FNUZ_FP8_SHIFT = 7.0f; - static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -172,9 +167,6 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, - const float* k_descale_ptr, - const float* v_descale_ptr, - const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -366,13 +358,6 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { - float k_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - k_descale = k_descale_ptr[kv_idx]; - } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -442,20 +427,11 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } - // dequant - auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return s_acc_element_func * k_descale; - } - else - return s_acc_element_func; - }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -473,7 +449,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -490,7 +466,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -595,21 +571,7 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - // For BLOCKSCALE: precompute (m - shift) once per row - // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) - // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) - auto validated_m = get_validated_m(m[i_idx]); - auto row_max = scale_s * validated_m; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap - row_max -= OCP_FP8_SHIFT; // for else branch -#else - validated_m -= FNUZ_FP8_SHIFT; - row_max -= FNUZ_FP8_SHIFT; -#endif - } + auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -617,13 +579,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { @@ -714,39 +676,18 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); - float v_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - v_descale = v_descale_ptr[kv_idx]; - } // STAGE 3, KV gemm - auto o_acc0 = decltype(o_acc){}; - clear_tile(o_acc0); - - auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return o_acc0; - } - else - { - return o_acc; - } - }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -781,16 +722,11 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - tile_elementwise_inout( - [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); - } } while(++i_total_loops < num_total_loop); // store lse @@ -910,9 +846,6 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 81bd8d5ab5..797e572d58 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,7 +46,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -65,10 +64,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; - // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] - static constexpr float OCP_FP8_SHIFT = 8.0f; - static constexpr float FNUZ_FP8_SHIFT = 7.0f; - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -195,9 +190,6 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, - const float* k_descale_ptr, - const float* v_descale_ptr, - const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -411,13 +403,6 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { - float k_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - k_descale = k_descale_ptr[kv_idx]; - } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -464,20 +449,11 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); - // dequant - auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return s_acc_element_func * k_descale; - } - else - return s_acc_element_func; - }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -495,7 +471,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -512,7 +488,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -654,21 +630,7 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - // For BLOCKSCALE: precompute (m - shift) once per row - // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) - // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) - auto validated_m = get_validated_m(m[i_idx]); - auto row_max = scale_s * validated_m; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { -#if CK_TILE_USE_OCP_FP8 - validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap - row_max -= OCP_FP8_SHIFT; // for else branch -#else - validated_m -= FNUZ_FP8_SHIFT; - row_max -= FNUZ_FP8_SHIFT; -#endif - } + auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -676,13 +638,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { @@ -773,27 +735,7 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); - float v_descale = 1.0f; - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - // K and V share the same seqlen_k position within a block - const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; - v_descale = v_descale_ptr[kv_idx]; - } // STAGE 3, KV gemm - auto o_acc0 = decltype(o_acc){}; - clear_tile(o_acc0); - - auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - return o_acc0; - } - else - { - return o_acc; - } - }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -803,7 +745,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); - gemm_1(o_acc_, + gemm_1(o_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -866,19 +808,13 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc_, + o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } - - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) - { - tile_elementwise_inout( - [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); - } } while(i_total_loops < num_total_loop); // store lse @@ -986,9 +922,6 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, - nullptr, - nullptr, - 1, sink_v); } }; From 7b3db1a878181004fc5db7cdb82840623beaadb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 23 Jan 2026 10:29:59 +0100 Subject: [PATCH 14/42] Grouped conv fwd direct load vector=2 (#3632) --- .../device_grouped_conv_fwd_xdl_mem_instance.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 838b14bf8e..e1b84d97b1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -76,7 +76,8 @@ using device_grouped_conv_fwd_xdl_bf16_direct_load_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, BF16, BF16, true>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, BF16, BF16, true>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, BF16, BF16, true>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, BF16, BF16, true> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, BF16, BF16, true> // clang-format on >; @@ -106,7 +107,8 @@ using device_grouped_conv_fwd_xdl_f16_direct_load_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F16, F16, true>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, F16, F16, true>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, F16, F16, true>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F16, F16, true> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F16, F16, true> // clang-format on >; From 81ee19bd2c9328001b8071647cdb6bdca8a4c5f6 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:19:51 +0100 Subject: [PATCH 15/42] WMMA grouped conv fwd large tensor extra flavors (#3582) * Additional flavors for WMMA conv fwd large tensor - added F16/BF16 clamp operation - added F16/BF16 bias_clamp operation - small modification to the device code to accomodate extra tensors * changed strategy to handle GemmArgs array * Adding generic instance * Added generic instance to clamp and bias_clamp ops --- ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 83 +++++---- include/ck/utility/array.hpp | 11 ++ ..._wmma_cshufflev3_large_tensor_instance.hpp | 52 +++++- ...grouped_convolution_forward_bias_clamp.hpp | 24 ++- ...ion_forward_bias_clamp_wmma_cshufflev3.inc | 164 ++++++++++++------ .../gpu/grouped_convolution_forward_clamp.hpp | 24 ++- ...volution_forward_clamp_wmma_cshufflev3.inc | 164 ++++++++++++------ .../CMakeLists.txt | 4 + ...hwgc_gkyxc_nhwgk_bf16_generic_instance.cpp | 40 +++++ ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 40 +++++ ...nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp | 40 +++++ ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 40 +++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 4 + ...hwgc_gkyxc_nhwgk_bf16_generic_instance.cpp | 40 +++++ ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 40 +++++ ...nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp | 40 +++++ ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 40 +++++ .../CMakeLists.txt | 4 + ...gc_gkzyxc_ndhwgk_bf16_generic_instance.cpp | 40 +++++ ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 40 +++++ ...wgc_gkzyxc_ndhwgk_f16_generic_instance.cpp | 40 +++++ ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 40 +++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 4 + ...gc_gkzyxc_ndhwgk_bf16_generic_instance.cpp | 40 +++++ ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 40 +++++ ...wgc_gkzyxc_ndhwgk_f16_generic_instance.cpp | 40 +++++ ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 40 +++++ 27 files changed, 1007 insertions(+), 171 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index 08d0f296f0..ed0ead42d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m); const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n); - GemmArgs new_args{}; - new_args.a_ptrs_ = p_as_grid; - new_args.b_ptrs_ = p_bs_grid; - new_args.ds_ptrs_ = p_ds_grid; - new_args.e_ptr_ = p_e_grid; - - new_args.a_element_op_ = a_element_op_; - new_args.b_element_op_ = b_element_op_; - new_args.cde_element_op_ = cde_element_op_; - - new_args.M_ = gemm_m; - new_args.N_ = gemm_n; - - new_args.a_grid_desc_ = a_grid_desc; - new_args.b_grid_desc_ = b_grid_desc; - new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + const auto ds_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, m_block, n_block); - new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ = + const auto e_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, m_block, n_block); - new_args.BlockStart_ = BlockStart; - new_args.BlockEnd_ = BlockEnd; - - gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args; + gemm_desc_kernel_args_.Emplace( + valid_gemms_count_, + GemmArgs{.a_ptrs_ = p_as_grid, + .b_ptrs_ = p_bs_grid, + .ds_ptrs_ = p_ds_grid, + .e_ptr_ = p_e_grid, + .a_element_op_ = a_element_op_, + .b_element_op_ = b_element_op_, + .cde_element_op_ = cde_element_op_, + .M_ = gemm_m, + .N_ = gemm_n, + .a_grid_desc_ = a_grid_desc, + .b_grid_desc_ = b_grid_desc, + .ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + ds_desc_mblock_mperblock_nblock_nperblock, + .e_grid_desc_mblock_mperblock_nblock_nperblock_ = + e_desc_mblock_mperblock_nblock_nperblock, + .BlockStart_ = BlockStart, + .BlockEnd_ = BlockEnd}); valid_gemms_count_++; } @@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - static_for<0, NumDTensor, 1>{}([&](auto i) { - compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; - compute_ptr_offset_of_n_.BatchStrideDs_(i) = - ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + }); + } } void Print() const @@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor << ", is_split_valid=" << std::boolalpha << is_split_valid_ << std::noboolalpha << ", grid_size=" << grid_size_ << std::endl; - static_for<0, NumDTensor, 1>{}([&](auto i) { - std::cout << " Ds[" << i.value - << "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i) - << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i) - << std::endl; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << " Ds[" << i.value << "] group stride=" + << compute_ptr_offset_of_groups_.BatchStrideDs_.At(i) + << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i) + << std::endl; + }); + } std::cout << "===== GEMM splits =====" << std::endl; for(index_t i = 0; i < valid_gemms_count_; ++i) @@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: " << gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl; - static_for<0, NumDTensor, 1>{}([&](auto d_idx) { - std::cout << " D" << d_idx.value << " descriptor: " - << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx) - << std::endl; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto d_idx) { + std::cout << " D" << d_idx.value << " descriptor: " + << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx) + << std::endl; + }); + } } } diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 2b249884b6..73eb18fe16 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -6,6 +6,8 @@ #include "functional2.hpp" #include "sequence.hpp" +#include +#include namespace ck { @@ -27,6 +29,15 @@ struct Array __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + template + __host__ constexpr auto Emplace(index_t i, Args&&... args) + -> std::enable_if_t> + { + assert(i >= 0 && i < NSize); + mData[i].~TData(); + new(mData + i) TData(ck::forward(args)...); + } + template __host__ __device__ constexpr auto operator=(const T& a) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp index c3769fbfd0..9c9e95101e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp @@ -29,12 +29,32 @@ using S = ck::Sequence; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; +using Clamp = ck::tensor_operation::element_wise::Clamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +template +using device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances = std::tuple< + // clang-format off + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, @@ -56,6 +77,24 @@ using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances = std::tuple< + // clang-format off + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template using device_grouped_conv_fwd_wmma_large_tensor_bf16_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index aa83fe8155..b3b36a657e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -293,8 +293,10 @@ struct DeviceOperationInstanceFactory>>& instances); -// void -// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( -// std::vector, -// NHWGK, -// BF16, -// BF16, -// Tuple, -// BF16, -// PassThrough, -// PassThrough, -// AddClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( std::vector>>& instances); -// void -// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( -// std::vector, -// NDHWGK, -// BF16, -// BF16, -// Tuple, -// BF16, -// PassThrough, -// PassThrough, -// AddClamp>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances); #endif @@ -203,20 +229,33 @@ void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_ PassThrough, AddClamp>>>& instances); -// void -// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( -// std::vector, -// NHWGK, -// F16, -// F16, -// Tuple, -// F16, -// PassThrough, -// PassThrough, -// AddClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( std::vector>>& instances); -// void -// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( -// std::vector, -// NDHWGK, -// F16, -// F16, -// Tuple, -// F16, -// PassThrough, -// PassThrough, -// AddClamp>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances); #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index a91d885d8c..5bad3b7c4f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -290,8 +290,10 @@ struct DeviceOperationInstanceFactory>>& instances); -// void -// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( -// std::vector, -// NHWGK, -// BF16, -// BF16, -// Tuple<>, -// BF16, -// PassThrough, -// PassThrough, -// Clamp>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( std::vector>>& instances); -// void -// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( -// std::vector, -// NDHWGK, -// BF16, -// BF16, -// Tuple<>, -// BF16, -// PassThrough, -// PassThrough, -// Clamp>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances); #endif @@ -256,35 +282,61 @@ void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f1 PassThrough, Clamp>>>& instances); -// void -// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( -// std::vector, -// NHWGK, -// F16, -// F16, -// Tuple<>, -// F16, -// PassThrough, -// PassThrough, -// Clamp>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); -// void -// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( -// std::vector, -// NDHWGK, -// F16, -// F16, -// Tuple<>, -// F16, -// PassThrough, -// PassThrough, -// Clamp>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); + +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances); #endif diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index 4f9c5d7a96..0023e15edf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -49,4 +49,8 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp new file mode 100644 index 0000000000..febf4d509a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..8430bd71b2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp new file mode 100644 index 0000000000..133ae48939 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..65ce350ccc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 3ba23f9384..1a091f1a4e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -49,4 +49,8 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp new file mode 100644 index 0000000000..6dfae833e2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..43447ea826 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp new file mode 100644 index 0000000000..28a398512a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..8c2a0578ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector, + NHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index d8b468931f..54e0a1392f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -44,6 +44,10 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp ) add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp new file mode 100644 index 0000000000..5cd12e6ede --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..5c6fa0c011 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp new file mode 100644 index 0000000000..5c6102bde3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..977f930b38 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + AddClamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 4aa4de8bc0..29d24f1d28 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -44,6 +44,10 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp ) add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp new file mode 100644 index 0000000000..6c4f89177f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..f4ff4ad0a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp new file mode 100644 index 0000000000..98313be0e8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..90838fe41a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Clamp>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 2e08a7e5ab51b020c90008b45c75dc35c2ba426c Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:20:00 +0100 Subject: [PATCH 16/42] WMMA grouped conv fwd large tensor bias bnorm clamp (#3595) * Added bias_bnorm_clamp for WMMA conv fwd large tensor. Following operations are added for FP16/BF16 data type and NHWGCxGKYXC layout. - grouped_conv2d_fwd_bias_bnorm_clamp - grouped_conv3d_fwd_bias_bnorm_clamp * changed strategy to handle GemmArgs array * Adding generic instance * fixed last nits from reviewers and copilot --- ...d_convolution_forward_bias_bnorm_clamp.hpp | 12 ++++ ...rward_bias_bnorm_clamp_wmma_cshufflev3.inc | 60 +++++++++++++++++++ .../CMakeLists.txt | 2 + ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 44 ++++++++++++++ ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 44 ++++++++++++++ .../CMakeLists.txt | 2 + ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 44 ++++++++++++++ ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 44 ++++++++++++++ 8 files changed, 252 insertions(+) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp index 295b2c21b5..e42a3f2045 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp @@ -297,6 +297,9 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw PassThrough, BiasNormalizeInInferClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index d089663f37..1f381f5f7d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -328,6 +328,8 @@ generate_sharded_instantiations( add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP} ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..f2729fe0e4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..7be4be2f1e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index dc759cbb54..f54588991f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -309,6 +309,8 @@ generate_sharded_instantiations( add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp ${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP} ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4a9c68b2d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..92c86b8df0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 67f0b74ec6687192fac14c359c57aca237d3cf2a Mon Sep 17 00:00:00 2001 From: ltqin Date: Sat, 24 Jan 2026 01:03:22 +0800 Subject: [PATCH 17/42] Revert "Revert " Fp8 block scale quantization for fmha fwd (#3330)" (#3633)" (#3635) This reverts commit de5a1d730dc77d1471ad53ca18dfd7c1474e9873. Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- CHANGELOG.md | 1 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 7 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 26 ++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 230 +++++++++++++---- example/ck_tile/01_fmha/quant.hpp | 7 + .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 5 +- include/ck_tile/core/numeric/math.hpp | 7 + include/ck_tile/core/utility/functional.hpp | 12 + .../host/reference/reference_batched_gemm.hpp | 40 +++ .../block_attention_quant_scale_enum.hpp | 6 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 242 +++++++++++++++++- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 83 +++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 83 +++++- 14 files changed, 667 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54c8b776dd..5f17a4d768 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. * Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. * Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming. +* Added FP8 block scale quantization for FMHA forward kernel. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index a3cfe2622a..cac6671ca5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -77,11 +77,13 @@ def get_mask_cpp_check_expr(mask: str) -> str: QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 81c7b067d3..b59f442663 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1024,7 +1024,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): # no need lse/dropout kernels for logits, qscale, mask, bias, sink in itertools.product( ["t", "f"], - ["no", "pertensor"], + ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], ["f", "t"], @@ -1152,7 +1152,10 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], + ["no", "pertensor", "blockscale"], + get_mask_map(mask_impl).keys(), + ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fdd720fd75..aedbb0e17c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,8 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) + const void* block_scale_seqstart_q_ptr; + const void* block_scale_seqstart_k_ptr; const void* sink_ptr; ck_tile::index_t seqlen_q; @@ -257,6 +259,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -264,6 +269,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -276,6 +284,9 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; }; struct fmha_fwd_pagedkv_args @@ -615,6 +626,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqstart_k_ptr, args.seqlen_q_ptr, args.seqlen_k_ptr, + args.block_scale_seqstart_q_ptr, + args.block_scale_seqstart_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -634,6 +647,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -642,6 +658,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); @@ -679,6 +697,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -686,6 +707,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.sink_size, @@ -693,6 +717,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_size_q, + args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0c988b2acc..b6287245a0 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -210,6 +210,11 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + // Note: block_scale_size_q_ and block_scale_size_kv_ should be greater than or equal to the + // compute block size + constexpr ck_tile::index_t block_scale_size_q_ = 128; + constexpr ck_tile::index_t block_scale_size_kv_ = 128; + const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -471,7 +476,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); + size_t i_block_scale_q = 0; + size_t i_block_scale_k = 0; + std::vector block_scale_seqstart_q_host = {0}; + std::vector block_scale_seqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -487,6 +496,10 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } + i_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_size_q_); + i_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_size_kv_); + block_scale_seqstart_q_host.push_back(i_block_scale_q); + block_scale_seqstart_k_host.push_back(i_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -548,6 +561,15 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); + const ck_tile::index_t num_block_scale_q = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_size_q_) + : i_block_scale_q; + const ck_tile::index_t num_block_scale_kv = + (mode == mode_enum::batch) + ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_size_kv_) + : i_block_scale_k; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor sink_host({nhead}); @@ -599,9 +621,18 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor q_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_q} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); + ck_tile::HostTensor v_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead_k, num_block_scale_kv} + : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -717,6 +748,12 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } + else if(qscale.type == quant_scale_enum::blockscale) + { + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); + ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -737,6 +774,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_scale_seqstart_q_buf(block_scale_seqstart_q_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem block_scale_seqstart_k_buf(block_scale_seqstart_k_host.size() * + sizeof(int32_t)); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -782,6 +823,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_descale_buf.ToDevice(q_descale_host.data()); k_descale_buf.ToDevice(k_descale_host.data()); v_descale_buf.ToDevice(v_descale_host.data()); + block_scale_seqstart_q_buf.ToDevice(block_scale_seqstart_q_host.data()); + block_scale_seqstart_k_buf.ToDevice(block_scale_seqstart_k_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -975,11 +1018,14 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_q; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_kv; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_kv; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -997,6 +1043,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_descale = num_block_scale_q * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_kv * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_kv * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1084,9 +1133,39 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + if(qscale.type == quant_scale_enum::blockscale) + { + args.q_descale_ptr = + reinterpret_cast(q_descale_buf.GetDeviceBuffer()); + args.k_descale_ptr = + reinterpret_cast(k_descale_buf.GetDeviceBuffer()); + args.v_descale_ptr = + reinterpret_cast(v_descale_buf.GetDeviceBuffer()); + + args.block_scale_seqstart_q_ptr = + (mode == mode_enum::group ? block_scale_seqstart_q_buf.GetDeviceBuffer() + : nullptr); + args.block_scale_seqstart_k_ptr = + (mode == mode_enum::group ? block_scale_seqstart_k_buf.GetDeviceBuffer() + : nullptr); + + args.nhead_stride_q_descale = nhead_stride_q_descale; + args.nhead_stride_k_descale = nhead_stride_k_descale; + args.nhead_stride_v_descale = nhead_stride_v_descale; + + args.batch_stride_q_descale = batch_stride_q_descale; + args.batch_stride_k_descale = batch_stride_k_descale; + args.batch_stride_v_descale = batch_stride_v_descale; + + args.block_scale_size_q = block_scale_size_q_; + args.block_scale_size_kv = block_scale_size_kv_; + } + else + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1589,14 +1668,42 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif // reference - ck_tile:: - reference_batched_gemm( + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + ck_tile::reference_batched_quant_gemm( q_host_ref, k_host_ref, s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s_host)); + ck_tile::idx_identity{}, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return value * scale_s * + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + std::get<1>(idx) / block_scale_size_q_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + std::get<2>(idx) / block_scale_size_kv_); + }); + } + else + { + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s_host)); + } if(0.f < logits_soft_cap) { @@ -1794,13 +1901,35 @@ fwd_result fmha_fwd_run(mode_enum mode, } } - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); + if(qscale.type == quant_scale_enum::blockscale) + { + const ck_tile::index_t v_offset = + (mode == mode_enum::batch) ? 0 : block_scale_seqstart_k_host[wb]; + ck_tile:: + reference_batched_quant_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::idx_identity{}, + [&](auto idx, auto value) { + return ck_tile::type_convert(value) * + v_descale_host(b_idx, + std::get<0>(idx) / nr, + v_offset + + std::get<2>(idx) / block_scale_size_kv_); + }, + ck_tile::idx_identity{}); + } + else + { + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + } ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); // clang-format off @@ -1808,7 +1937,6 @@ fwd_result fmha_fwd_run(mode_enum mode, if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err(o_host_result, o_host_ref, @@ -1866,31 +1994,33 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale ? "no_scale" - : "pertensor", - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale + ? "no_scale" + : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 59d4ac1707..feb28cba24 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,6 +13,7 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, + blockscale, }; struct quant_scale_info @@ -25,6 +26,8 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; + else if(type == quant_scale_enum::blockscale) + os << "bs"; } static quant_scale_info decode(std::string str) @@ -38,6 +41,10 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } + else if(str == "bs" || str == "2") + { + info.type = quant_scale_enum::blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 596542eb9d..227f26c8f3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -95,10 +95,11 @@ run_fp8bf16_tests() { for perm in 0 1 ; do for b in 1 2 ; do for hdim in 64 128 256 ; do + for scale in 1 2; do - $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=$scale -kname=$KNAME $COMMON_ARGS - done ; done ; done + done ; done ; done ; done } run_fp8fp32_tests() { diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 96e76f669d..a46ae509dd 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -37,6 +37,13 @@ struct scales return lhs_ * rhs; } + template + CK_TILE_HOST_DEVICE constexpr auto operator*(OtherScale other) const + { + auto new_scale = lhs_ * other; + return scales>(new_scale); + } + private: Scale lhs_; }; diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 898d21574e..aa4bfa3f15 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -119,6 +119,18 @@ struct identity } }; +// Similar to identity, but takes an additional index parameter as the first argument. +// The index is ignored and only the second argument (value) is forwarded. +// Useful for indexed element-wise operations where the functor signature requires an index. +struct idx_identity +{ + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(I&& /*idx*/, T&& arg) const noexcept + { + return std::forward(arg); + } +}; + namespace detail { // RemainLengths: sequence<...> diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 63f13b1b16..d742426740 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -47,4 +47,44 @@ CK_TILE_HOST void reference_batched_gemm(const HostTensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor& a_b_m_k, + const HostTensor& b_b_n_k, + HostTensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const int N = b_b_n_k.mDesc.get_lengths()[1]; + const int K = b_b_n_k.mDesc.get_lengths()[2]; + + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + AccDataType v_a = ck_tile::type_convert( + a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k))); + AccDataType v_b = ck_tile::type_convert( + b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k))); + + v_acc += v_a * v_b; + } + + c_b_m_n(batch, m, n) = ck_tile::type_convert( + acc_element_op(std::make_tuple(batch, m, n), v_acc)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 3755a2bc71..7e0f704bef 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, + BLOCKSCALE, }; template @@ -27,5 +28,10 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "blockscale"; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index adbedc5259..0039c57cfc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -168,6 +168,29 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; + struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs + { + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + ck_tile::index_t block_scale_size_q; + ck_tile::index_t block_scale_size_kv; + }; + + struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + + struct FmhaFwdGroupBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + const int32_t* block_scale_seqstart_q_ptr; + const int32_t* block_scale_seqstart_k_ptr; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -243,9 +266,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t> { @@ -269,9 +295,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -328,6 +357,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -335,6 +367,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -343,6 +378,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -413,6 +450,23 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -478,6 +532,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -485,6 +542,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -492,6 +552,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -528,6 +590,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -535,6 +600,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -542,6 +610,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -581,6 +651,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -588,6 +661,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -595,6 +671,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -631,6 +709,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -638,6 +719,9 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -645,6 +729,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -666,6 +752,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -685,6 +773,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -694,6 +785,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -763,6 +856,24 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.block_scale_size_q = block_scale_size_q; + kargs.block_scale_size_kv = block_scale_size_kv; + + kargs.block_scale_seqstart_q_ptr = + reinterpret_cast(block_scale_seqstart_q_ptr); + kargs.block_scale_seqstart_k_ptr = + reinterpret_cast(block_scale_seqstart_k_ptr); + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -814,6 +925,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -833,6 +946,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -841,6 +957,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -860,6 +978,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -879,6 +999,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -887,6 +1010,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -909,6 +1034,8 @@ struct FmhaFwdKernel const void* seqstart_k_ptr, const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* block_scale_seqstart_q_ptr, + const void* block_scale_seqstart_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -928,6 +1055,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, @@ -936,6 +1066,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_size_q, + ck_tile::index_t block_scale_size_kv, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr, const void* sink_ptr = nullptr) @@ -955,6 +1087,8 @@ struct FmhaFwdKernel seqstart_k_ptr, seqlen_q_ptr, seqlen_k_ptr, + block_scale_seqstart_q_ptr, + block_scale_seqstart_k_ptr, hdim_q, hdim_v, num_head_q, @@ -974,6 +1108,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, sink_size, @@ -982,6 +1119,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_size_q, + block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr); @@ -1111,13 +1250,16 @@ struct FmhaFwdKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + long_index_t batch_offset_q_descale = 0; + long_index_t batch_offset_k_descale = 0; + long_index_t batch_offset_v_descale = 0; const float sink_value = kargs.sink_ptr != nullptr ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s @@ -1153,6 +1295,14 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const long_index_t bquery_start = kargs.block_scale_seqstart_q_ptr[i_batch]; + const long_index_t bkey_start = kargs.block_scale_seqstart_k_ptr[i_batch]; + batch_offset_q_descale = bquery_start; + batch_offset_k_descale = bkey_start; + batch_offset_v_descale = bkey_start; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1220,6 +1370,15 @@ struct FmhaFwdKernel batch_offset_randval = static_cast(i_batch) * kargs.batch_stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + batch_offset_q_descale = + static_cast(i_batch) * kargs.batch_stride_q_descale; + batch_offset_k_descale = + static_cast(i_batch) * kargs.batch_stride_k_descale; + batch_offset_v_descale = + static_cast(i_batch) * kargs.batch_stride_v_descale; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // If cumulative seqlen pointers are provided, override per-batch effective lengths @@ -1540,7 +1699,8 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { + + auto o_acc_tile = [&, i_nhead_ = i_nhead]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { // TODO - move global load of descale to pipeline @@ -1581,8 +1741,62 @@ struct FmhaFwdKernel block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_value); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + size_t idx = i_m0 / kargs.block_scale_size_q; + float q_descale = q_descale_ptr[idx]; + // BLOCKSCALE: P is scaled in exp2(x+shift) where shift=7 or 8 + // Both P and rowsum are scaled by 2^shift, canceling in normalization + // No additional scaling needed in p_compute_element_func or o_acc_element_func + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + scales(q_descale), // s_acc_element_func + identity{}, // p_compute_element_func - No scaling (done in exp2) + identity{}, // o_acc_element_func - No dequant needed (canceled by rowsum) + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_size_kv, + sink_value); + } else { return FmhaPipeline{}(q_dram_window, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index dcccdf541c..2fbc9fdb54 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -57,8 +57,13 @@ struct BlockFmhaPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate) @@ -167,6 +172,9 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -358,6 +366,13 @@ struct BlockFmhaPipelineQRKSVS static_assert(1 <= k1_loops); do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -427,11 +442,20 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); schedule_gemm0(); } + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -449,7 +473,7 @@ struct BlockFmhaPipelineQRKSVS { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -466,7 +490,7 @@ struct BlockFmhaPipelineQRKSVS } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -571,7 +595,21 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -579,13 +617,13 @@ struct BlockFmhaPipelineQRKSVS if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -676,18 +714,39 @@ struct BlockFmhaPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } + move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -722,11 +781,16 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(++i_total_loops < num_total_loop); // store lse @@ -846,6 +910,9 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 797e572d58..81bd8d5ab5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -46,6 +46,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto QScaleEnum = Problem::QScaleEnum; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); @@ -64,6 +65,10 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasSink = Problem::kHasSink; + // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] + static constexpr float OCP_FP8_SHIFT = 8.0f; + static constexpr float FNUZ_FP8_SHIFT = 7.0f; + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || !kHasLogitsSoftCap)) || @@ -190,6 +195,9 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float* k_descale_ptr, + const float* v_descale_ptr, + const index_t block_scale_size_kv, const float sink_v) const { static_assert( @@ -403,6 +411,13 @@ struct BlockFmhaPipelineQRKSVSAsync // main loop do { + float k_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + k_descale = k_descale_ptr[kv_idx]; + } // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -449,11 +464,20 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); + // dequant + auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return s_acc_element_func * k_descale; + } + else + return s_acc_element_func; + }(); // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -471,7 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsync { const auto k_origin = k_dram_block_window.get_window_origin(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( @@ -488,7 +512,7 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + s_acc = tile_elementwise_in(s_acc_element_func_, s_acc); if constexpr(kHasLogitsSoftCap) { auto apply_logits_transform = @@ -630,7 +654,21 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); + // For BLOCKSCALE: precompute (m - shift) once per row + // Bias/Alibi/SoftCap: exp2(s - m + shift) = exp2(s - (m - shift)) + // else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift)) + auto validated_m = get_validated_m(m[i_idx]); + auto row_max = scale_s * validated_m; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { +#if CK_TILE_USE_OCP_FP8 + validated_m -= OCP_FP8_SHIFT; // for Bias/Alibi/SoftCap + row_max -= OCP_FP8_SHIFT; // for else branch +#else + validated_m -= FNUZ_FP8_SHIFT; + row_max -= FNUZ_FP8_SHIFT; +#endif + } #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -638,13 +676,13 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { if constexpr(kHasLogitsSoftCap) { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp2(s[i_j_idx] - validated_m); } else { @@ -735,7 +773,27 @@ struct BlockFmhaPipelineQRKSVSAsync #endif }(); + float v_descale = 1.0f; + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + // K and V share the same seqlen_k position within a block + const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; + v_descale = v_descale_ptr[kv_idx]; + } // STAGE 3, KV gemm + auto o_acc0 = decltype(o_acc){}; + clear_tile(o_acc0); + + auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + return o_acc0; + } + else + { + return o_acc; + } + }(); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -745,7 +803,7 @@ struct BlockFmhaPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); - gemm_1(o_acc, + gemm_1(o_acc_, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), get_slice_tile( @@ -808,13 +866,19 @@ struct BlockFmhaPipelineQRKSVSAsync { block_sync_lds(); gemm_1( - o_acc, + o_acc_, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), get_slice_tile( v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } + + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + tile_elementwise_inout( + [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0); + } } while(i_total_loops < num_total_loop); // store lse @@ -922,6 +986,9 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, + nullptr, + nullptr, + 1, sink_v); } }; From e1c46ff548cf7bc8b0e1b41a3d559f05317ec2da Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Fri, 23 Jan 2026 21:39:03 +0100 Subject: [PATCH 18/42] Remove code duplications in batched gemm wmma (#3580) * Moved device struct for batched gemm wmma to a common file Signed-off-by: Chris Tsiaousis * Use the common device struct in the scaled batched gemm wmma implementation Signed-off-by: Chris Tsiaousis * Boy-scout: Remove unused includes and ambiguous comment Signed-off-by: Chris Tsiaousis * Moved pointer offset calculation and gridwise argument to common struct This change enables further code reduction by re-using the common structs for the batched gemm and batched gemm b scale wmma implementations. Signed-off-by: Chris Tsiaousis * Moved type string to the common struct of DeviceBatchedGemm_Wmma_CShuffleV3_Common" Signed-off-by: Chris Tsiaousis --------- Signed-off-by: Chris Tsiaousis --- .../device_batched_gemm_wmma_cshuffle_v3.hpp | 518 ++--------------- ..._batched_gemm_wmma_cshuffle_v3_b_scale.hpp | 533 ++---------------- ...e_batched_gemm_wmma_cshuffle_v3_common.hpp | 529 +++++++++++++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 116 ++++ 4 files changed, 719 insertions(+), 977 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index a18f108e47..94c339f643 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -13,105 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -271,36 +178,6 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; - }; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, @@ -354,330 +231,40 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm; // PermuteB not supported by DeviceBatchedGemm base class. + using DeviceGemmCommon = DeviceBatchedGemm_Wmma_CShuffleV3_Common< + GridwiseGemm, + Tuple, + Tuple, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // IsBScaled + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + // Argument - struct Argument : public GridwiseGemm::Argument - { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t Batch_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation cde_element_op_, - bool is_reduce_ = false) - : GridwiseGemm::Argument(std::array{p_a_grid_}, - std::array{p_b_grid_}, - std::array{}, // p_ds_grid_ - p_c_grid_, - M_, - N_, - K_, - std::array{StrideA_}, - std::array{StrideB_}, - std::array{}, // StrideDs_ - StrideC_, - k_batch_, - a_element_op_, - b_element_op_, - cde_element_op_, - is_reduce_), - Batch(Batch_), - compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} - { - } + using Argument = typename DeviceGemmCommon::Argument; - index_t Batch; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - }; - - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. - gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. - // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - // TODO: This is not part of the DeviceBatchedGemm base class but it was part of - // DeviceBatchedGemmV2. Remove? - // index_t GetKPerBlock() override { return KPerBlock; } - // bool GetPermuteA() override { return PermuteA; } - // bool GetPermuteB() override { return PermuteB; } - static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, @@ -762,48 +349,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceBatchedGemm_Wmma_CShuffleV3" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " - << "WaveTile: " - << MPerWmma << "x"<(); } REGISTER_EXTRA_PRINTING_METHODS }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index b88f071a96..d682ca4ffa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -13,109 +13,12 @@ #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_batched_gemm_b_scale_wmma_cshuffle_v3( - typename GridwiseGemm::Argument karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - using EpilogueType = - typename std::conditional::type; - - constexpr index_t LDS_size = - GridwiseGemm::template GetSharedMemoryNumberOfByte(); - // The normal approach to batching would be to increase the grid size by just stretching out - // the grid Z dimension (which is the outermost dimension), but this depends on lower level - // functions not directly using the Z dimension for other calculations. As it turns out, k - // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now - // we will use the grid Y dimension for batching. This may be a bit fragile. - __shared__ char p_shared[LDS_size]; - - const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); - - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - const long_index_t b_scale_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - // shift A matrices pointer for splitk - typename GridwiseGemm::AsGridPointer p_as_grid_shift; - static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { - using ADataType_ = - remove_cvref_t>; - p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + - splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; - }); - - // shift B matrices pointer for splitk - typename GridwiseGemm::BsGridPointer p_bs_grid_shift; - static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { - using BDataType_ = - remove_cvref_t>; - p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + - splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; - }); - - auto epilogue_args = EpilogueType{}; - - GridwiseGemm::template Run( - p_as_grid_shift, - p_bs_grid_shift, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_a_scale_grid, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; - ignore = compute_ptr_offset_of_batch; -#endif -} - /// @brief \"Universal\" Batched GEMM operation without SplitK support. /// /// @par Overview @@ -282,45 +185,6 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale static_assert(PermuteB == false, "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); - struct ComputePtrOffsetOfStridedBatch - { - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideC, - index_t BatchStrideScaleB) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideC_(BatchStrideC), - BatchStrideScaleB_(BatchStrideScaleB) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_) / GridwiseGemm::BPackedSize; - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideScaleB_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; - index_t BatchStrideScaleB_; - }; - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< ALayout, @@ -379,328 +243,40 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale PermuteA, // PermuteA not supported by DeviceBatchedGemm base class. PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class. + using DeviceGemmCommon = DeviceBatchedGemm_Wmma_CShuffleV3_Common< + GridwiseGemm, + Tuple, + Tuple, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true, // IsBScaled + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GridwiseGemm::BPackedSize, + BScaleDataType>; + // Argument - struct Argument : public GridwiseGemm::Argument - { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - index_t StrideC_, - index_t StrideScaleB_, - index_t BatchStrideA_, - index_t BatchStrideB_, - index_t BatchStrideC_, - index_t BatchStrideScaleB_, - const BScaleDataType* p_b_scale_grid_, - index_t Batch_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - bool is_reduce_ = false) - : GridwiseGemm::Argument(std::array{p_a_grid_}, - std::array{p_b_grid_}, - std::array{}, // p_ds_grid_ - p_c_grid_, - M_, - N_, - K_, - std::array{StrideA_}, - std::array{StrideB_}, - std::array{}, // StrideDs_ - StrideC_, - 0, // StrideScaleA - StrideScaleB_, - nullptr, - p_b_scale_grid_, - k_batch_, - a_element_op_, - b_element_op_, - c_element_op_, - is_reduce_), - Batch(Batch_), - compute_ptr_offset_of_batch{ - BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_} - { - } + using Argument = typename DeviceGemmCommon::Argument; - index_t Batch; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; - }; - - /// @brief Helper structure responsible for kernel invocation. - /// - /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU - /// kernel function. It usually determines the launched grid size prepares kernel - /// arguments as well as perform specific kernel configuration selection based on - /// runtime arguments. - /// - /// @note If appropriately configured it may measure kernel execution time. - /// - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - // The normal approach to batching would be to increase the grid size by just stretching - // out the grid Z dimension (which is the outermost dimension), but this depends on - // lower level functions not directly using the Z dimension for other calculations. As - // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. - // Therefore, for now we will use the grid Y dimension for batching. This may be a bit - // fragile. - gdy *= arg.Batch; - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - // Packed sizes are 1 for all implemented data types but we include it anyway - // for future compatibility. - // Note: the grid descriptors and size_a / size_b do *not* take batching into - // account, so we have to manually multiply overall buffer sizes for rotating - // memory by batch. - std::array size_as_buffers; - size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; - - std::array size_bs_buffers; - size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - Tuple<>> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - std::array{}); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - ck::utility::flush_icache(); - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_, - arg_.compute_ptr_offset_of_batch); - } - else - { - auto clear_workspace = [&]() { - // clear c mem - if(arg.KBatch > 1) - // Note: we multiply by batch since we want to clear the C matrix for - // the whole batch. Untested since we don't have k batching ATM. - // Note: This seems incorrect for non-contiguous memory layouts for C - // (padding, gaps). - HIP_CHECK_ERROR( - hipMemsetAsync(arg.p_e_grid, - 0, - arg.Batch * arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - clear_workspace, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg, - arg.compute_ptr_offset_of_batch); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - else - { - throw std::runtime_error("Pipeline not implemented"); - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = kernel_batched_gemm_b_scale_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } index_t GetKPerBlock() override { return KPerBlock; } @@ -801,48 +377,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale // polymorphic std::string GetTypeString() const override { - auto str = std::stringstream(); - - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " - << "WaveTile: " - << MPerWmma << "x"<(); } REGISTER_EXTRA_PRINTING_METHODS }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..59a820861c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,529 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template > +struct DeviceBatchedGemm_Wmma_CShuffleV3_Common +{ + struct ComputePtrOffsetOfStridedBatch + { + template > + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + template > + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + static_assert(BPackedSize != 0); + static_assert(IsBScaled || (!IsBScaled && BPackedSize == 1)); + return g_idx * static_cast(BatchStrideB_) / BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(*BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + std::optional BatchStrideScaleB_; + }; + + struct Argument : public GridwiseGemm::Argument + { + using ADataType = typename AsDataType::DataType; + using BDataType = typename BsDataType::DataType; + template > + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t Batch_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation cde_element_op_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(std::array{p_a_grid_}, + std::array{p_b_grid_}, + std::array{}, // p_ds_grid_ + p_c_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + std::array{}, // StrideDs_ + StrideC_, + k_batch_, + a_element_op_, + b_element_op_, + cde_element_op_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} + { + static_assert(std::is_same_v>); + } + + template > + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(std::array{p_a_grid_}, + std::array{p_b_grid_}, + std::array{}, // p_ds_grid_ + p_c_grid_, + M_, + N_, + K_, + std::array{StrideA_}, + std::array{StrideB_}, + std::array{}, // StrideDs_ + StrideC_, + 0, // StrideScaleA + StrideScaleB_, + nullptr, + p_b_scale_grid_, + k_batch_, + a_element_op_, + b_element_op_, + c_element_op_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_} + { + static_assert(!std::is_same_v>); + } + + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. + std::array size_as_buffers; + size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumATensor / GridwiseGemm::APackedSize * + arg_.Batch; + + std::array size_bs_buffers; + size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + GridwiseGemm::NumBTensor / GridwiseGemm::BPackedSize * + arg_.Batch; + + ck::utility:: + RotatingMemWrapperMultiABD> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + std::array{}); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + ck::utility::flush_icache(); + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_e_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_e_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + using ComputePtrOffsetOfStridedBatch = decltype(arg.compute_ptr_offset_of_batch); + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + else + { + throw std::runtime_error("Pipeline not implemented"); + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + IsBScaled>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + template + static std::string GetTypeString() + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + constexpr auto type = []() { + if constexpr(IsBScaled) + { + return "DeviceBatchedGemm_Wmma_CShuffleV3_BScale"; + } + else + { + return "DeviceBatchedGemm_Wmma_CShuffleV3"; + } + }(); + // clang-format off + str << type + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + + auto epilogue_args = EpilogueType{}; + + if constexpr(IsBScaled) + { + const long_index_t b_scale_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx)); + + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + + splitk_batch_offset.scale_b_k_split_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + else + { + GridwiseGemm::template Run( + p_as_grid_shift, + p_bs_grid_shift, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + template Date: Fri, 23 Jan 2026 16:14:22 -0700 Subject: [PATCH 19/42] [CK_TILE] Fix alignment in Stream-K workspace buffer (#3625) * Fix alignment issue in Stream-K workspace buffer In CK Tile Stream-K, the workspace buffer is used to hold flags and partials, where the first i bytes holds the flags and the remaining bytes hold partials. This change adds padding to the flags prefix of the workspace buffer to ensure the number of bytes is 128B-aligned. Without this alignment, since workgroups do not skip cache when reading from partials, they may read stale partials data in cache, leading to incorrect results. The added padding avoids the stale data reading. This change also re-enables the test_ck_tile_streamk_reduction tests. * Compute reference GEMM on GPU for test verification to decrease testing time --- .../streamk_gemm_tile_partitioner.hpp | 3 +- .../streamk_gemm_tile_partitioner_impl.hpp | 5 +- test/ck_tile/gemm_streamk/CMakeLists.txt | 7 ++- .../gemm_streamk/test_gemm_streamk_util.hpp | 32 ++++++++++--- .../test_streamk_tile_partitioner.cpp | 37 ++++++++++++++- .../test_streamk_tile_partitioner_common.hpp | 47 +++++++++++++++++-- 6 files changed, 115 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 0b0f6c18ef..f028ba0c62 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -42,7 +42,8 @@ struct StreamKTilePartitionerBase CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept; /** - * @brief Calculates the total space needed for the flags buffer. + * @brief Calculates the total space needed for the flags buffer whose total byte size is + * 128B-aligned. * * @return index_t The number of bytes needed for the flags buffer. */ diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index 1764a1ce83..f80eec844c 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -58,7 +58,10 @@ CK_TILE_HOST_DEVICE index_t StreamKTilePartitionerBase::get_flags_buffer_size() const noexcept { - return sizeof(index_t) * sk_ctas_; + constexpr index_t alignment = 128; + const index_t required_bytes = sizeof(index_t) * sk_ctas_; + const index_t padded_bytes = ck_tile::integer_least_multiple(required_bytes, alignment); + return padded_bytes; } template diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 6aaa145c7d..1390e5ee07 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,10 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) - # TODO: Renable once transient bug for reduction is resolved. - # add_gtest_executable(test_ck_tile_streamk_reduction - # ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp - # test_gemm_streamk_util.cpp) + add_gtest_executable(test_ck_tile_streamk_reduction + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp + test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 237dc24c3b..96f90a5c2d 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -262,20 +262,40 @@ class TestCkTileStreamK : public ::testing::Test c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - ck_tile::HostTensor c_m_n_host_ref( + // Calculate reference GEMM on the GPU + ck_tile::HostTensor c_m_n_dev_ref( f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - c_m_n_host_ref.SetZero(); + ck_tile::DeviceMem ref_c_m_n_dev_buf(c_m_n_dev_ref.get_element_space_size_in_bytes()); + ref_c_m_n_dev_buf.SetZero(); - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); + ADataType* a_m_k_dev_ref_ptr = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* b_k_n_dev_ref_ptr = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* c_m_n_dev_ref_ptr = static_cast(ref_c_m_n_dev_buf.GetDeviceBuffer()); + ck_tile::reference_gemm_gpu(a_m_k_dev_ref_ptr, + b_k_n_dev_ref_ptr, + c_m_n_dev_ref_ptr, + M, + N, + K, + stride_A, + stride_B, + stride_C); + ref_c_m_n_dev_buf.FromDevice(c_m_n_dev_ref.data()); const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + *std::max_element(c_m_n_dev_ref.mData.begin(), c_m_n_dev_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( K, num_accumulations_per_tile, max_accumulated_value); bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, + c_m_n_dev_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 637f71c04f..30b1b878c5 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -51,6 +51,39 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase) validate_streamk_base_constructor(expected_values, tile_partitioner); } +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128); +} + +TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes) +{ + using Config = StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes; + + ck_tile::StreamKTilePartitionerBase + tile_partitioner{Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256); +} + TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy) { using Config = StreamKTilePartitionerBaseConfigDP2TileSK; @@ -71,7 +104,9 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) ck_tile::index_t expected_partials_size = sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID; - ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID; + // Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of + // the flags array is 128B-aligned. + ck_tile::index_t expected_flags_size = 128; EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), expected_partials_size + expected_flags_size); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 3daec049a7..31217ba101 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -198,9 +198,11 @@ struct StreamKTilePartitionerBaseConfig struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig { - static constexpr ck_tile::index_t M = 28; - static constexpr ck_tile::index_t N = 4; - static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + // The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure + // the total byte size of the array is 128B-aligned, the flags array must be 128B. static constexpr ck_tile::index_t GRID = 3; static constexpr ck_tile::index_t M_TILE = 4; @@ -212,6 +214,45 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner ck_tile::sequence>; }; +struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 32; + // The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the + // number of bytes for the flags array should be 128B. + static constexpr ck_tile::index_t GRID = 32; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + +struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 28; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 33; + // The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the + // number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size + // of the array is 128B-aligned. + static constexpr ck_tile::index_t GRID = 33; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 1; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile : public StreamKTilePartitionerBaseConfig { From 7ac379428408337a231a86f8a8b7353b5b45aa2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Sun, 25 Jan 2026 14:42:23 +0200 Subject: [PATCH 20/42] Add new instances for merging multiple fwd conv groups into a single GEMM batch. Allow group merging for C > 1 when vector load/store size is 1 for the output tensor. (#3639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ville Pietilä <> --- ...vice_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 2 +- ...ice_grouped_conv_fwd_xdl_merged_groups_instance.hpp | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cc343f6f69..d3e0d6057d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1513,7 +1513,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if constexpr(NumGroupsToMerge > 1) { - if(!(C == 1)) + if(!(C == 1) && CDEBlockTransferScalarPerVector_NPerBlock > 1) { return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 944e68f192..18abcb1613 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -116,9 +116,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8> // clang-format on >; From 054c437dec3bc0d0059f045dc768b950db315846 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 26 Jan 2026 09:23:19 -0800 Subject: [PATCH 21/42] add dockerfile for manylinux (#3651) --- Dockerfile.manylinux | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 Dockerfile.manylinux diff --git a/Dockerfile.manylinux b/Dockerfile.manylinux new file mode 100644 index 0000000000..0683bcd4a6 --- /dev/null +++ b/Dockerfile.manylinux @@ -0,0 +1,101 @@ +FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=7.2 +ARG compiler_version="" +ARG compiler_commit="" +ARG CK_SCCACHE="" +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive + +USER root + +# Add rocm repository +RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y + +RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \ + dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \ + dnf update -y && \ + dnf install python3-setuptools python3-wheel -y && \ + dnf install rocm-dev -y + +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi + +# Install dependencies +RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \ + cmake \ + clang-tools-extra \ + gcc-c++ \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + git \ + hip-rocclr \ + jq \ + mpich \ + net-tools \ + pkg-config \ + redis \ + sshpass \ + stunnel \ + vim \ + nano \ + zip \ + openssh-server \ + kmod && \ + dnf clean all && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +#Install latest ccache + git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ +#Install latest cppcheck + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ + cd / && \ +# Install packages for processing the performance results + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ +# Add render group + groupadd -f render && \ +# Install the new rocm-cmake version + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install + +WORKDIR / +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + From de59c0716c631edfa4742e4309ee11d4379ef6e8 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:08:55 -0800 Subject: [PATCH 22/42] Optimize sequence metaprogramming utilities to reduce template instantiation depth (#3585) This change significantly improves compile-time performance by reducing template instantiation depth for sequence generation and merging operations: Optimizations: - sequence_gen: Reduce instantiation depth from O(log N) to O(1) by using __make_integer_seq to generate indices in a single step, then applying the functor via pack expansion - uniform_sequence_gen: Similarly optimized to O(1) depth using __make_integer_seq with a helper that applies a constant value via pack expansion - sequence_merge: Reduce depth from O(N) to O(log N) using binary tree reduction strategy. Added direct concatenation specializations for 1-4 sequences to avoid recursion in common cases, falling back to binary tree merging for 5+ sequences Documentation: - Added extensive inline comments explaining why sequence_merge cannot achieve O(1) depth like sequence_gen (requires computing cumulative sequence lengths from heterogeneous inputs, inherently requiring recursion) - Documented the binary tree reduction approach and why it's superior to fold expressions for this use case Testing: - Added comprehensive unit tests for uniform_sequence_gen with different values, sizes, and edge cases - Added tests for sequence_gen with custom functors (double, square, identity, constant) to verify the new implementation works with arbitrary functors - Added tests for sequence_merge with 4, 5, and many sequences to verify both the direct concatenation path and binary tree reduction path - Added tests for empty sequence edge cases --- include/ck/utility/sequence.hpp | 152 +++++++++++++----- .../ck/utility/statically_indexed_array.hpp | 1 + test/util/unit_sequence.cpp | 134 +++++++++++++++ 3 files changed, 247 insertions(+), 40 deletions(-) diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048..3a45d52bd3 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -199,55 +199,113 @@ template using make_index_sequence = typename __make_integer_seq::seq_type; -// merge sequence -template -struct sequence_merge +// merge sequence - optimized to avoid recursive instantiation +// +// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1) +// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why: +// +// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each +// element can be computed independently: output[i] = f(i) +// +// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths. +// To compute output[i], we need to know: +// 1. Which input sequence contains this index +// 2. The offset within that sequence +// This requires computing cumulative sequence lengths, which requires recursion/iteration. +// +// Instead, we use a binary tree reduction approach that achieves O(log N) instantiation depth: +// - Base cases handle 1-4 sequences directly (O(1) for common cases) +// - Recursive case merges pairs then combines: merge(s1,s2) + merge(s3,s4,...) +// - This gives O(log N) depth, which is optimal for merging heterogeneous sequences +// +// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to +// linear dependency chain, so binary tree is superior. +// +namespace detail { + +// Helper to concatenate multiple sequences in one step using fold expression +template +struct sequence_merge_impl; + +// Base case: single sequence +template +struct sequence_merge_impl> { - using type = typename sequence_merge::type>::type; + using type = Sequence; }; +// Two sequences: direct concatenation template -struct sequence_merge, Sequence> +struct sequence_merge_impl, Sequence> { using type = Sequence; }; -template -struct sequence_merge +// Three sequences: direct concatenation (avoids one level of recursion) +template +struct sequence_merge_impl, Sequence, Sequence> { - using type = Seq; + using type = Sequence; }; -// generate sequence +// Four sequences: direct concatenation +template +struct sequence_merge_impl, Sequence, Sequence, Sequence> +{ + using type = Sequence; +}; + +// General case: binary tree reduction (O(log N) depth instead of O(N)) +template +struct sequence_merge_impl +{ + // Merge pairs first, then recurse + using left = typename sequence_merge_impl::type; + using right = typename sequence_merge_impl::type; + using type = typename sequence_merge_impl::type; +}; + +} // namespace detail + +template +struct sequence_merge +{ + using type = typename detail::sequence_merge_impl::type; +}; + +template <> +struct sequence_merge<> +{ + using type = Sequence<>; +}; + +// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation +namespace detail { + +// Helper that applies functor F to indices and produces a Sequence +// __make_integer_seq produces sequence_gen_helper +template +struct sequence_gen_helper +{ + // Apply a functor F to all indices at once via pack expansion (O(1) depth) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + template struct sequence_gen { - template - struct sequence_gen_impl - { - static constexpr index_t NRemainLeft = NRemain / 2; - static constexpr index_t NRemainRight = NRemain - NRemainLeft; - static constexpr index_t IMiddle = IBegin + NRemainLeft; + using type = + typename __make_integer_seq::template apply; +}; - using type = typename sequence_merge< - typename sequence_gen_impl::type, - typename sequence_gen_impl::type>::type; - }; - - template - struct sequence_gen_impl - { - static constexpr index_t Is = G{}(Number{}); - using type = Sequence; - }; - - template - struct sequence_gen_impl - { - using type = Sequence<>; - }; - - using type = typename sequence_gen_impl<0, NSize, F>::type; +template +struct sequence_gen<0, F> +{ + using type = Sequence<>; }; // arithmetic sequence @@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1> using type = typename __make_integer_seq::type; }; -// uniform sequence +// uniform sequence - optimized using __make_integer_seq +namespace detail { + +template +struct uniform_sequence_helper +{ + // Apply a constant value to all indices via pack expansion + template + using apply = Sequence<((void)Is, Value)...>; +}; + +} // namespace detail + template struct uniform_sequence_gen { - struct F - { - __host__ __device__ constexpr index_t operator()(index_t) const { return I; } - }; + using type = typename __make_integer_seq:: + template apply; +}; - using type = typename sequence_gen::type; +template +struct uniform_sequence_gen<0, I> +{ + using type = Sequence<>; }; // reverse inclusive scan (with init) sequence diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index d0735a32f6..f3d73e84a7 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -20,6 +20,7 @@ struct tuple_concat, Tuple> using type = Tuple; }; +// StaticallyIndexedArrayImpl uses binary split for O(log N) depth template struct StaticallyIndexedArrayImpl { diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp index f09fd86e06..9e62b9a6c0 100644 --- a/test/util/unit_sequence.cpp +++ b/test/util/unit_sequence.cpp @@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize) EXPECT_TRUE((is_same::value)); } +TEST(SequenceGen, UniformSequenceSingleElement) +{ + using Result = typename uniform_sequence_gen<1, 99>::type; + using Expected = Sequence<99>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceDifferentValues) +{ + using Result1 = typename uniform_sequence_gen<3, 0>::type; + using Expected1 = Sequence<0, 0, 0>; + EXPECT_TRUE((is_same::value)); + + using Result2 = typename uniform_sequence_gen<4, -5>::type; + using Expected2 = Sequence<-5, -5, -5, -5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceLargeSize) +{ + // Test with larger size to verify __make_integer_seq implementation + using Result = typename uniform_sequence_gen<16, 7>::type; + using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>; + EXPECT_TRUE((is_same::value)); +} + // Test make_index_sequence TEST(SequenceGen, MakeIndexSequence) { @@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero) EXPECT_TRUE((is_same::value)); } +// Test sequence_gen with custom functors +TEST(SequenceGen, SequenceGenWithDoubleFunctor) +{ + struct DoubleFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * 2; } + }; + using Result = typename sequence_gen<5, DoubleFunctor>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenWithSquareFunctor) +{ + struct SquareFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * i; } + }; + using Result = typename sequence_gen<5, SquareFunctor>::type; + using Expected = Sequence<0, 1, 4, 9, 16>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenZeroSize) +{ + struct IdentityFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i; } + }; + using Result = typename sequence_gen<0, IdentityFunctor>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); + // Also verify non-zero size works with identity + using Result5 = typename sequence_gen<5, IdentityFunctor>::type; + EXPECT_TRUE((is_same>::value)); +} + +TEST(SequenceGen, SequenceGenSingleElement) +{ + struct ConstantFunctor + { + __host__ __device__ constexpr index_t operator()(index_t) const { return 42; } + }; + using Result = typename sequence_gen<1, ConstantFunctor>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_merge TEST(SequenceMerge, MergeTwoSequences) { @@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence) EXPECT_TRUE((is_same::value)); } +TEST(SequenceMerge, MergeFourSequences) +{ + // Test the 4-sequence specialization + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2, 3>; + using Seq3 = Sequence<4, 5, 6>; + using Seq4 = Sequence<7, 8>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeFiveSequences) +{ + // Test the binary tree reduction path (5+ sequences) + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3>; + using Seq4 = Sequence<4>; + using Seq5 = Sequence<5>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeManySequences) +{ + // Test with many sequences to stress the binary tree reduction + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3, 4>; + using Seq4 = Sequence<5>; + using Seq5 = Sequence<6, 7>; + using Seq6 = Sequence<8>; + using Seq7 = Sequence<9, 10>; + using Seq8 = Sequence<11, 12>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeEmptySequences) +{ + // Test merging empty sequences + using Seq1 = Sequence<>; + using Seq2 = Sequence<1, 2>; + using Seq3 = Sequence<>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeZeroSequences) +{ + // Test the empty specialization + using Result = typename sequence_merge<>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_split TEST(SequenceSplit, SplitInMiddle) { From 917f35553a46286eb3364abec4de5267d2aa92b0 Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Mon, 26 Jan 2026 19:20:30 +0100 Subject: [PATCH 23/42] Remove code duplications in batched gemm (multi D) gemm (multi D) wmma (#3617) * Added common struct to enable code reduction in gemm gemm and gemm multi_d gemm multi_d wmma implementation This file includes all shared components. The (shared between the two implementations) kernel, the pointer offset computation struct, the grid descriptor creator and definitions, the invoker struct and the argument struct. Signed-off-by: Chris Tsiaousis * Used the common struct in the batched gemm gemm wmma cshuffle v3 implementation Signed-off-by: Chris Tsiaousis * Used the shared structs in the gemm multiple D gemm multiple D wmma cshuffle v3 implementation Signed-off-by: Chris Tsiaousis * Boy-scout: IWYU paradigm in the gemm gemm and gemm multiple D gemm multiple D wmma cshuffle v3 implementations Signed-off-by: Chris Tsiaousis --------- Signed-off-by: Chris Tsiaousis --- ...ice_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 618 +++--------- ...ched_gemm_gemm_wmma_cshuffle_v3_common.hpp | 902 ++++++++++++++++++ ...ple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp | 816 +++------------- 3 files changed, 1173 insertions(+), 1163 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 45ec3a2065..6b1144047f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -3,77 +3,21 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" #include "ck/utility/tuple.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - Tuple<>{}, // p_d0s_grid - arg.p_b1_grid + b1_batch_offset, - Tuple<>{}, // p_d1s_grid - arg.p_c_grid + c_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.b1_grid_desc, - Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.c_element_op, - arg.block_2_ctile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes C = A * B0 * B1 // MN = MK * KL * LN // ^^^^^^ (Acc0) @@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC) - : BatchStrideA_(BatchStrideA), - BatchStrideB0_(BatchStrideB0), - BatchStrideB1_(BatchStrideB1), - BatchStrideC_(BatchStrideC) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB0_; - index_t BatchStrideB1_; - index_t BatchStrideC_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType + Tuple<>, // D1sDataType + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0GridDesc - B1GridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, Tuple<>, // Ds1GridDesc - CGridDesc_M_N, + typename DeviceGemmGemmCommonBase::CGridDesc_M_N, // Tiling Family MPerBlock, LPerBlock, @@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + Tuple<>, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType, + Tuple<>, // D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - const B1DataType* p_b1_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - index_t StrideB1, - index_t StrideC, - index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_b1_grid{p_b1_grid_}, - p_c_grid{p_c_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - c_g_m_o_lengths = arr3{batch_count, M, O}; - c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides); - c_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1); - } - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - const B1DataType* p_b1_grid; - CDataType* p_c_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - arr3 c_g_m_o_lengths; - arr3 c_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - B1GridDesc b1_grid_desc; - CGridDesc_M_N c_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B layout must be Column\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, - arg.b1_grid_desc, - Tuple<>{}, - arg.c_grid_desc_m_n, - arg.block_2_ctile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto c_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto c_stride_lowest = arg.c_g_m_o_strides[2]; - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - c_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_gemm_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - // polymorphic std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm(static_cast(p_a), - static_cast(p_b0), - static_cast(p_b1), - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideB1, - StrideC, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - BatchStrideC, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + + std::array p_d0_grid{}; + std::array p_d1_grid{}; + std::array StrideD0s{}, BatchStrideD0s{}; + std::array StrideD1s, BatchStrideD1s{}; + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0_grid, + static_cast(p_b1), + p_d1_grid, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideC, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..a739af898f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,902 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_e1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx))); + + auto [p_d0s_grid, p_d1s_grid] = [&]() { + if constexpr(IsMultiD) + { + auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) { + static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) { + const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In)); + grid_pointer(In) = arg_grid(In) + batch_offset; + }); + return std::move(grid_pointer); + }; + auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx); + }; + auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx); + }; + auto d0s_grid = create_grid(ck::integral_constant{}, + get_d0_base_ptr, + arg.p_d0s_grid, + GridwiseOp::MakeD0sGridPointer()); + auto d1s_grid = create_grid(ck::integral_constant{}, + get_d1_base_ptr, + arg.p_d1s_grid, + GridwiseOp::MakeD1sGridPointer()); + return std::make_pair(d0s_grid, d1s_grid); + } + else + { + return std::make_pair(Tuple<>{}, Tuple<>{}); + } + }(); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + p_d0s_grid, + arg.p_b1_grid + b1_batch_offset, + p_d1s_grid, + arg.p_c_e1_grid + c_e1_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, + arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.cde1_element_op, + arg.block_2_etile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common +{ + static constexpr ck::index_t NumD0Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD0Tensor; + } + return 0; + }(); + static constexpr ck::index_t NumD1Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD1Tensor; + } + return 0; + }(); + + struct GridDescriptorCreator + { + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, + const std::array& d0_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); + } + + __host__ __device__ static auto MakeD0sGridDescriptor( + const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, + const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto MakeD1sGridDescriptor( + const std::array, NumD1Tensor>& d1_g_m_o_lengths_vec, + const std::array, NumD1Tensor>& d1_g_m_o_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto + MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, + const std::array& e1_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); + } + }; + + using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {})); + using D0sGridDesc = + remove_cvref_t; + using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {})); + using D1sGridDesc = + remove_cvref_t; + using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {})); + using CGridDesc_M_N = + decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB0_(BatchStrideB0), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_E1_(BatchStrideC) + { + } + + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideC_E1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_E1_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d0_idx]); + } + + template + __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideC_E1_; + }; +}; + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg +{ + using GridwiseGemm = typename DeviceOp::GridwiseOp; + using Common = + DeviceGemmGemm_Wmma_CShuffleV3_Common; + + static constexpr auto NumD0Tensor = Common::NumD0Tensor; + static constexpr auto NumD1Tensor = Common::NumD1Tensor; + + struct Argument : public BaseArgument + { + using arr3 = std::array; + + Argument(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + std::array p_d0s_grid_, + const B1DataType* p_b1_grid_, + std::array p_d1s_grid_, + CE1DataType* p_e1_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t StrideA, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CDE1ElementwiseOperation cde1_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_d0s_grid{}, + p_b1_grid{p_b1_grid_}, + p_d1s_grid{}, + p_c_e1_grid{p_e1_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + cde1_element_op{cde1_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + e1_g_m_o_lengths = arr3{batch_count, M, O}; + e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] + + a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths, + a_g_m_k_strides); + b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths, + b0_g_n_k_strides); + b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths, + b1_g_o_n_strides); + c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor( + e1_g_m_o_lengths, e1_g_m_o_strides); + c_e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_e1_grid_desc_m_n); + block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1); + + if constexpr(IsMultiD) + { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + + // D0s layout [batch_count, M, N] + d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; + d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; + + // D0 pointer + p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); + }); + // D0 desc + d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor( + d0s_g_m_n_lengths, d0s_g_m_n_strides); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1DataType = remove_cvref_t>; + + // D1s layout [batch_count, M, O] + d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; + d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; + + // D1 pointer + p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); + }); + // D1 desc + d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor( + d1s_g_m_o_lengths, d1s_g_m_o_strides); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc); + } + } + + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + typename GridwiseGemm::D0sGridPointer p_d0s_grid; + const B1DataType* p_b1_grid; + typename GridwiseGemm::D1sGridPointer p_d1s_grid; + CE1DataType* p_c_e1_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + std::array d0s_g_m_n_lengths; + std::array d0s_g_m_n_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + std::array d1s_g_m_o_lengths; + std::array d1s_g_m_o_strides; + arr3 e1_g_m_o_lengths; + arr3 e1_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CDE1ElementwiseOperation cde1_element_op; + + // Grid descriptors and other mem calculators + typename Common::AGridDesc a_grid_desc; + typename Common::B0GridDesc b0_grid_desc; + std::conditional_t> d0s_grid_desc; + typename Common::B1GridDesc b1_grid_desc; + typename Common::D1sGridDesc d1s_grid_desc; + std::conditional_t< + IsMultiD, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + Tuple<>> + d1s_grid_desc_mblock_mperblock_nblock_nperblock; + + std::conditional_t + c_e1_grid_desc_m_n; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_e1_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map; + + typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); + + const index_t grid_size = arg.batch_count * M0 * N0; + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_loop = decltype(has_main_k_block_loop)::value; + constexpr TailNumber tail_num = decltype(tail_number)::value; + const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3; + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + // check if DsLayout is supported + template + static constexpr bool CheckDLayout() + { + bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Print lambda with env check and printf() style formmating. + const char* curFunc = __func__; + auto print = [&curFunc](const char* format, ...) -> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. + if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. + if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if constexpr(IsMultiD) + { + if constexpr(!(is_same_v)) + { + print("DeviceOp: B0 layout must be Column\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D0s layout must be Row\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D1s layout must be Row\n"); + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto cde1_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; + + // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major + // and the lowest dimension stride is hardcoded to 1 + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + e1_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + else + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + Tuple<>{}, + arg.b1_grid_desc, + Tuple<>{}, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto c_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto c_stride_lowest = arg.e1_g_m_o_strides[2]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp index 06651c0c0e..83fec9c95f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -3,91 +3,20 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t e1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx))); - - auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer(); - auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer(); - - static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset; - }); - - static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset; - }); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - p_d0s_grid, - arg.p_b1_grid + b1_batch_offset, - p_d1s_grid, - arg.p_e1_grid + e1_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.cde1_element_op, - arg.block_2_etile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes: // Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...) // E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...) @@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal // to LPerWmma (A.k.a Gemm0 NPerWmma). static constexpr index_t NPerWmma = LPerWmma; - // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler - // Transform operator or just not use one at all. - using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< - Sequence<1, 1, 1, 1, 1>, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, - const std::array& d0_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); - } - - __host__ __device__ static auto MakeD0sGridDescriptor( - const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, - const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto MakeD1sGridDescriptor( - const std::array, NumD0Tensor>& d1_g_m_o_lengths_vec, - const std::array, NumD0Tensor>& d1_g_m_o_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto - MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, - const std::array& e1_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using D0sGridDesc = remove_cvref_t; - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using D1sGridDesc = remove_cvref_t; - using E1GridDesc = decltype(MakeE1GridDescriptor({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1) - : BatchStrideA0_(BatchStrideA0), - BatchStrideB0_(BatchStrideB0), - BatchStrideD0s_(BatchStrideD0s), - BatchStrideB1_(BatchStrideB1), - BatchStrideD1s_(BatchStrideD1s), - BatchStrideE1_(BatchStrideE1) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA0_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - template - __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, - Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD0s_[d1_idx]); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE1_); - } - - template - __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD1s_[d1_idx]); - } - - private: - index_t BatchStrideA0_; - index_t BatchStrideB0_; - std::array BatchStrideD0s_; - index_t BatchStrideB1_; - std::array BatchStrideD1s_; - index_t BatchStrideE1_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, // InMemory Data Descriptor - AGridDesc, - B0GridDesc, - D0sGridDesc, - B1GridDesc, - D1sGridDesc, - E1GridDesc, + typename DeviceGemmGemmCommonBase::AGridDesc, + typename DeviceGemmGemmCommonBase::B0GridDesc, + typename DeviceGemmGemmCommonBase::D0sGridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, + typename DeviceGemmGemmCommonBase::D1sGridDesc, + typename DeviceGemmGemmCommonBase::E1GridDesc, // Tiling Family MPerBlock, LPerBlock, @@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, - Transform::matrix_padder.PadN, + DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer>; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + D0sLayout, + B1Layout, + D1sLayout, + E1Layout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + E1DataType, + D0sDataType, + D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + CDE0BlockTransferSrcScalarPerVector, + CShuffleBlockTransferScalarPerVector_NPerBlock, + true>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - std::array p_d0s_grid_, - const B1DataType* p_b1_grid_, - std::array p_d1s_grid_, - E1DataType* p_e1_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - std::array StrideD0s, - index_t StrideB1, - std::array StrideD1s, - index_t StrideE1, - index_t BatchStrideA, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CDE1ElementwiseOperation cde1_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_d0s_grid{}, - p_b1_grid{p_b1_grid_}, - p_d1s_grid{}, - p_e1_grid{p_e1_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - cde1_element_op{cde1_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - e1_g_m_o_lengths = arr3{batch_count, M, O}; - e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides); - e1_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n); - block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1); - - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - using D0DataType = remove_cvref_t>; - - // D0s layout [batch_count, M, N] - d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; - d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; - - // D0 pointer - p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); - - // D0 desc - d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]); - }); - - static_for<0, NumD1Tensor, 1>{}([&](auto i) { - using D1DataType = remove_cvref_t>; - - // D1s layout [batch_count, M, O] - d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; - d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; - - // D1 pointer - p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); - - // D1 desc - d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]); - }); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc); - } - - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - typename GridwiseOp::D0sGridPointer p_d0s_grid; - const B1DataType* p_b1_grid; - typename GridwiseOp::D1sGridPointer p_d1s_grid; - E1DataType* p_e1_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - std::array d0s_g_m_n_lengths; - std::array d0s_g_m_n_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - std::array d1s_g_m_o_lengths; - std::array d1s_g_m_o_strides; - arr3 e1_g_m_o_lengths; - arr3 e1_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CDE1ElementwiseOperation cde1_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - D0sGridDesc d0s_grid_desc; - B1GridDesc b1_grid_desc; - D1sGridDesc d1s_grid_desc; - typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock; - - E1GridDesc e1_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - // check if DsLayout is supported - template - static constexpr bool CheckDLayout() - { - bool valid = true; - // iterate over DLayout tuple - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - // if RefLayout and DLayout are same, keep valid true, otherwise false - valid = valid && is_same_v; - }); - return valid; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B0 layout must be Column\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D0s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D1s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc, - arg.e1_grid_desc_m_n, - arg.block_2_etile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto cde1_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; - - // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major - // and the lowest dimension stride is hardcoded to 1 - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - e1_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; - } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - static auto MakeArgument(const ADataType* p_a0, const B0DataType* p_b0, std::array p_d0s, @@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op) { - return RawArg{p_a0, p_b0, - p_d0s, p_b1, - p_d1s, p_e1, - MRaw, NRaw, - KRaw, Gemm1NRaw, - Batch, StrideA0, - StrideB0, StrideD0s, - StrideB1, StrideD1s, - StrideE1, BatchStrideA0, - BatchStrideB0, BatchStrideD0s, - BatchStrideB1, BatchStrideD1s, - BatchStrideE1, a0_element_op, - b0_element_op, cde0_element_op, - b1_element_op, cde1_element_op}; + return Argument{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; } // polymorphic @@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation c_element_op) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b0), - p_d0s, - static_cast(p_b1), - p_d1s, - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideD0s, - StrideB1, - StrideD1s, - StrideE1, - BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } From 834642202c0cb39df1b96dacc24d5c3b3d97e62c Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Mon, 26 Jan 2026 20:23:26 +0200 Subject: [PATCH 24/42] Re enable f8 x bf8 tests on compv3 and compv4 (#3605) * Re-enable f8 x bf8 tests on CompV3 as they now pass * On CompV4, fp8 x bf8 tests now pass with K_BlockSize I32 * Add a changelog entry --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- CHANGELOG.md | 1 + test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp | 9 ++------- test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768..c99fc1d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index ebe17aadd6..016f7be60d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3 static constexpr bool check_data_type() { using Base = TestCkTileGemmPipeline>; - if constexpr(std::is_same_v && - std::is_same_v) - { - return false; - } - else if constexpr(std::is_same_v && - std::is_same_v) + if constexpr(std::is_same_v && + std::is_same_v) { return false; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 334e360eb5..4bef581254 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> From 3900e1e7ceacfa32cb8d1522260ed30befd4dae3 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 26 Jan 2026 10:29:28 -0800 Subject: [PATCH 25/42] Solve the CTAD regression & add up the Shell file for the docker management in testing (#3634) * Finished the work * Fix the pipeline --- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- .../ck_tile/ops/reduce/block/block_reduce.hpp | 4 - .../ops/softmax/block/block_softmax_2d.hpp | 2 +- script/tools/ck-build | 143 +++++++++++++++ script/tools/ck-clean | 113 ++++++++++++ script/tools/ck-exec | 111 ++++++++++++ script/tools/ck-logs | 134 ++++++++++++++ script/tools/ck-shell | 84 +++++++++ script/tools/ck-start | 103 +++++++++++ script/tools/ck-status | 153 ++++++++++++++++ script/tools/ck-stop | 141 +++++++++++++++ script/tools/ck-test | 166 ++++++++++++++++++ 12 files changed, 1150 insertions(+), 6 deletions(-) create mode 100755 script/tools/ck-build create mode 100755 script/tools/ck-clean create mode 100755 script/tools/ck-exec create mode 100755 script/tools/ck-logs create mode 100755 script/tools/ck-shell create mode 100755 script/tools/ck-start create mode 100755 script/tools/ck-status create mode 100755 script/tools/ck-stop create mode 100755 script/tools/ck-test diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index c4ab1d4a78..34d18cb8e1 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); else return make_static_tile_distribution( - tile_distribution_encoding< // + tile_distribution_encoding< sequence, tuple, sequence>, diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 4284e7622f..3f59e2d036 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -392,8 +392,4 @@ struct BlockReduce2D InDataType reduce_init; }; -// deduction guide -template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; - } // namespace ck_tile diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp index abb95934ff..58e768b319 100644 --- a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp @@ -40,7 +40,7 @@ struct BlockSoftmax2D #endif // compute row max - auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; + auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; #if _BLOCK_SOFTMAX_USE_UNPACK2 auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{}); #else diff --git a/script/tools/ck-build b/script/tools/ck-build new file mode 100755 index 0000000000..2c0bb24eda --- /dev/null +++ b/script/tools/ck-build @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Build - Build Composable Kernel targets in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Build - Build Composable Kernel targets in Docker + +Usage: ck-build [options] [target...] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + -j Parallel jobs (passed to ninja) + --clean Clean before building + +Arguments: + target Target(s) to build (default: all) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-build # Build all targets + ck-build test_amdgcn_mma # Build specific target + ck-build test_amdgcn_mma test_gemm # Build multiple targets + ck-build --reconfigure # Reconfigure CMake and build all + ck-build --clean test_amdgcn_mma # Clean and build target + ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs + +EOF +} + +# Parse arguments +targets=() +reconfigure=false +clean=false +parallel_jobs="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --clean) + clean=true + shift + ;; + -j) + parallel_jobs="-j $2" + shift 2 + ;; + *) + targets+=("$1") + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Clean if requested +if [ "$clean" = true ]; then + echo "Cleaning build directory..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "" +fi + +# Build targets +if [ ${#targets[@]} -eq 0 ]; then + echo "Building all configured targets..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} 2>&1 + " +else + echo "Building targets: ${targets[*]}" + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} ${targets[*]} 2>&1 + " +fi + +echo "" +echo "Build complete ✓" diff --git a/script/tools/ck-clean b/script/tools/ck-clean new file mode 100755 index 0000000000..4b422f81f4 --- /dev/null +++ b/script/tools/ck-clean @@ -0,0 +1,113 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Clean - Clean build artifacts in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Clean - Clean build artifacts in Docker container + +Usage: ck-clean [options] + +Options: + -h, --help Show this help message + --name Specify container name + --all Remove entire build directory + -f, --force Force without confirmation + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-clean # Clean build artifacts (ninja clean) + ck-clean --all # Remove entire build directory + ck-clean --force --all # Remove build directory without confirmation + +EOF +} + +# Parse arguments +remove_all=false +force=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + remove_all=true + shift + ;; + -f|--force) + force=true + shift + ;; + *) + echo "Unknown option: $1" + show_help + exit 1 + ;; + esac +done + +# Check if container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running" + echo "Start with: ck-start" + exit 1 +fi + +# Check if build directory exists +if ! docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + echo "Build directory does not exist" + exit 0 +fi + +if [ "$remove_all" = true ]; then + # Remove entire build directory + if [ "$force" = false ]; then + read -p "Remove entire build directory? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "Removing build directory..." + docker exec "${CONTAINER_NAME}" bash -c "rm -rf /workspace/build" + echo "Build directory removed ✓" +else + # Clean with ninja + if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Build not configured (build.ninja not found)" + echo "Use --all to remove build directory" + exit 1 + fi + + echo "Cleaning build artifacts..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "Build artifacts cleaned ✓" +fi diff --git a/script/tools/ck-exec b/script/tools/ck-exec new file mode 100755 index 0000000000..dfc7655774 --- /dev/null +++ b/script/tools/ck-exec @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Exec - Execute arbitrary commands in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Exec - Execute arbitrary commands in Docker container + +Usage: ck-exec [options] [args...] + +Options: + -h, --help Show this help message + --name Specify container name + -w Working directory (default: /workspace) + -i, --interactive Interactive mode (allocate TTY) + +Arguments: + command Command to execute (required) + args Arguments to the command + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-exec rocm-smi # Run rocm-smi + ck-exec rocminfo # Run rocminfo + ck-exec ls -la build/bin # List build binaries + ck-exec -w /workspace/build ninja -t commands # Run ninja commands + ck-exec --interactive python3 # Interactive Python session + +Common Commands: + ck-exec rocm-smi # Check GPU status + ck-exec rocminfo \| grep gfx # Check GPU architecture + ck-exec hipcc --version # Check HIP compiler version + ck-exec cmake --version # Check CMake version + ck-exec ninja -C build -t targets # List all build targets + +EOF +} + +# Parse arguments +workdir="/workspace" +interactive=false +command_args=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -w) + workdir="$2" + shift 2 + ;; + -i|--interactive) + interactive=true + shift + ;; + *) + command_args+=("$1") + shift + ;; + esac +done + +# Validate command +if [ ${#command_args[@]} -eq 0 ]; then + echo "Error: command required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Build command string +cmd_string="" +for arg in "${command_args[@]}"; do + cmd_string="${cmd_string} $(printf '%q' "$arg")" +done + +# Execute command +if [ "$interactive" = true ]; then + docker exec -it -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +else + docker exec -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +fi diff --git a/script/tools/ck-logs b/script/tools/ck-logs new file mode 100755 index 0000000000..cfad23b3b5 --- /dev/null +++ b/script/tools/ck-logs @@ -0,0 +1,134 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Logs - View container logs and build output + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Logs - View container logs and build output + +Usage: ck-logs [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -f, --follow Follow log output + -n, --tail Show last N lines (default: 100) + --cmake Show CMake configuration log + --build Show last build log + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-logs # Show last 100 lines of container logs + ck-logs -f # Follow container logs + ck-logs -n 500 # Show last 500 lines + ck-logs --cmake # Show CMake configuration + ck-logs --build # Show build log + +EOF +} + +# Parse arguments +follow=false +tail_lines=100 +show_cmake=false +show_build=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -f|--follow) + follow=true + shift + ;; + -n|--tail) + tail_lines="$2" + shift 2 + ;; + --cmake) + show_cmake=true + shift + ;; + --build) + show_build=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Check if container exists +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 1 +fi + +# Show CMake log +if [ "$show_cmake" = true ]; then + echo "CMake Configuration Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/CMakeCache.txt 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build + echo 'GPU_TARGETS:' \$(grep 'GPU_TARGETS:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_BUILD_TYPE:' \$(grep 'CMAKE_BUILD_TYPE:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_CXX_COMPILER:' \$(grep 'CMAKE_CXX_COMPILER:' CMakeCache.txt | cut -d'=' -f2) + echo 'BUILD_TESTING:' \$(grep 'BUILD_TESTING:' CMakeCache.txt | cut -d'=' -f2) + " + else + echo "CMake not configured yet" + fi + exit 0 +fi + +# Show build log (last build output) +if [ "$show_build" = true ]; then + echo "Last Build Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/.ninja_log 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c "tail -50 /workspace/build/.ninja_log" + else + echo "No build log found" + fi + exit 0 +fi + +# Show container logs +echo "Container Logs (${CONTAINER_NAME}):" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if [ "$follow" = true ]; then + docker logs -f "${CONTAINER_NAME}" +else + docker logs --tail "${tail_lines}" "${CONTAINER_NAME}" +fi diff --git a/script/tools/ck-shell b/script/tools/ck-shell new file mode 100755 index 0000000000..785c9f4d68 --- /dev/null +++ b/script/tools/ck-shell @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Shell - Open interactive shell in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Shell - Open interactive shell in Docker container + +Usage: ck-shell [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -c Execute command instead of interactive shell + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-shell # Open interactive shell + ck-shell my_container # Open shell in specific container + ck-shell -c "rocm-smi" # Execute single command + ck-shell -c "cd build && ls bin" # Execute command in build directory + +EOF +} + +# Parse arguments +command="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -c) + command="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Execute command or open shell +if [ -n "$command" ]; then + echo "Executing: ${command}" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "${command}" +else + echo "Opening shell in '${CONTAINER_NAME}' (type 'exit' to leave)..." + docker exec -it "${CONTAINER_NAME}" bash +fi diff --git a/script/tools/ck-start b/script/tools/ck-start new file mode 100755 index 0000000000..f15477492a --- /dev/null +++ b/script/tools/ck-start @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Start - Start Docker container for Composable Kernel testing + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Start - Start Docker container for Composable Kernel testing + +Usage: ck-start [options] [container_name] + +Options: + -h, --help Show this help message + --image Specify Docker image (overrides CK_DOCKER_IMAGE) + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) + +Examples: + ck-start # Start container with default name + ck-start my_ck_container # Start container with custom name + ck-start --image rocm/composable_kernel:latest + +EOF +} + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --image) + export CK_DOCKER_IMAGE="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Get Docker image +DOCKER_IMAGE=$(get_docker_image) + +# Check if container exists and is running +if container_exists "${CONTAINER_NAME}"; then + if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is already running" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + else + echo "Starting existing container '${CONTAINER_NAME}'..." + docker start "${CONTAINER_NAME}" + echo "Container started" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + fi +fi + +# Create new container +echo "Creating new Docker container '${CONTAINER_NAME}'..." +echo "Docker image: ${DOCKER_IMAGE}" +echo "Project root: ${PROJECT_ROOT}" +echo "" + +docker run -d \ + --name "${CONTAINER_NAME}" \ + --device=/dev/kfd --device=/dev/dri \ + --security-opt seccomp=unconfined \ + --group-add video \ + -v "${PROJECT_ROOT}":/workspace \ + -w /workspace \ + "${DOCKER_IMAGE}" \ + tail -f /dev/null + +echo "" +echo "Container '${CONTAINER_NAME}' started successfully" +docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + +# Show GPU info +echo "" +echo "GPU Information:" +docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -5 || echo 'No GPU detected'" diff --git a/script/tools/ck-status b/script/tools/ck-status new file mode 100755 index 0000000000..fea9de8c36 --- /dev/null +++ b/script/tools/ck-status @@ -0,0 +1,153 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Status - Check container status and information + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Status - Check container status and information + +Usage: ck-status [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + --all Show all CK containers + -v, --verbose Show detailed information + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-status # Check default container status + ck-status my_container # Check specific container + ck-status --all # Show all CK containers + ck-status -v # Show detailed information + +EOF +} + +# Parse arguments +show_all=false +verbose=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + show_all=true + shift + ;; + -v|--verbose) + verbose=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +DOCKER_IMAGE=$(get_docker_image) + +# Show all containers +if [ "$show_all" = true ]; then + echo "Composable Kernel Docker Containers:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" 2>/dev/null || echo "") + + if [ -z "$containers" ] || [ "$containers" = "NAMES STATUS CREATED AT" ]; then + echo "No CK containers found for user '${username}'" + else + echo "$containers" + fi + exit 0 +fi + +# Check specific container status +echo "Container: ${CONTAINER_NAME}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if container_is_running "${CONTAINER_NAME}"; then + echo "Status: RUNNING ✓" + echo "" + docker ps --filter "name=^${CONTAINER_NAME}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}" + + if [ "$verbose" = true ]; then + echo "" + echo "Container Details:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker inspect "${CONTAINER_NAME}" --format ' +Image: {{.Config.Image}} +Created: {{.Created}} +Platform: {{.Platform}} +Mounts: {{range .Mounts}} + - {{.Source}} -> {{.Destination}}{{end}} +' + fi + + echo "" + echo "GPU Information:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'" + + if [ "$verbose" = true ]; then + echo "" + echo "GPU Target:" + gpu_target=$(detect_gpu_target "${CONTAINER_NAME}") + echo " ${gpu_target}" + + echo "" + echo "Build Status:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + if docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo " CMake configured ✓" + echo " Build directory: /workspace/build" + + # Count built test binaries + bin_count=$(docker exec "${CONTAINER_NAME}" bash -c "ls -1 /workspace/build/bin 2>/dev/null | wc -l" || echo "0") + echo " Test binaries: ${bin_count}" + else + echo " CMake not configured" + fi + else + echo " Build directory not found" + fi + fi + +elif container_exists "${CONTAINER_NAME}"; then + echo "Status: STOPPED" + echo "" + echo "Start with: ck-start" +else + echo "Status: DOES NOT EXIST" + echo "" + echo "Create with: ck-start" +fi diff --git a/script/tools/ck-stop b/script/tools/ck-stop new file mode 100755 index 0000000000..b793f47408 --- /dev/null +++ b/script/tools/ck-stop @@ -0,0 +1,141 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Stop - Stop and remove Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Stop - Stop and remove Docker container + +Usage: ck-stop [options] [container_name] + +Options: + -h, --help Show this help message + -f, --force Force stop without confirmation + --all Stop all CK containers for this user + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-stop # Stop default container + ck-stop my_ck_container # Stop specific container + ck-stop --all # Stop all user's CK containers + ck-stop --force # Stop without confirmation + +EOF +} + +# Parse arguments +force=false +stop_all=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + -f|--force) + force=true + shift + ;; + --all) + stop_all=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Function to stop a single container +stop_container() { + local name="$1" + + if ! container_exists "${name}"; then + echo "Container '${name}' does not exist" + return 1 + fi + + echo "Stopping and removing container '${name}'..." + docker stop "${name}" 2>/dev/null || true + docker rm "${name}" 2>/dev/null || true + echo "Container '${name}' stopped and removed" +} + +# Stop all user containers +if [ "$stop_all" = true ]; then + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format '{{.Names}}') + + if [ -z "$containers" ]; then + echo "No CK containers found for user '${username}'" + exit 0 + fi + + echo "Found CK containers for user '${username}':" + echo "$containers" + echo "" + + if [ "$force" = false ]; then + read -p "Stop and remove all these containers? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "" + while IFS= read -r container; do + stop_container "$container" + done <<< "$containers" + + echo "" + echo "All containers stopped and removed" + exit 0 +fi + +# Stop single container +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 0 +fi + +# Show container info +if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is currently running" +else + echo "Container '${CONTAINER_NAME}' exists but is stopped" +fi + +# Confirm if not forced +if [ "$force" = false ]; then + read -p "Stop and remove container '${CONTAINER_NAME}'? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi +fi + +stop_container "${CONTAINER_NAME}" diff --git a/script/tools/ck-test b/script/tools/ck-test new file mode 100755 index 0000000000..712f904596 --- /dev/null +++ b/script/tools/ck-test @@ -0,0 +1,166 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Test - Build and test Composable Kernel in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Test - Build and test Composable Kernel in Docker + +Usage: ck-test [options] [test_options] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + --no-build Skip building, run test directly + +Arguments: + test_name Name of test executable (required) + test_options Additional options passed to test (e.g., --gtest_filter=*) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-test test_amdgcn_mma + ck-test test_amdgcn_mma --gtest_filter=*Fp16* + ck-test --name my_container test_amdgcn_mma + ck-test --reconfigure test_amdgcn_mma + +EOF +} + +# Parse arguments +test_name="" +reconfigure=false +no_build=false +test_options=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --no-build) + no_build=true + shift + ;; + --gtest_*|--help) + test_options+=("$1") + shift + ;; + *) + if [ -z "$test_name" ]; then + test_name="$1" + else + test_options+=("$1") + fi + shift + ;; + esac +done + +# Validate test name +if [ -z "$test_name" ]; then + echo "Error: test_name required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Build test if needed (unless --no-build is specified) +if [ "$no_build" = false ]; then + if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then + echo "Building ${test_name}..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + else + echo "Test executable found, rebuilding to ensure latest version..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + fi +fi + +# Run test +echo "Running: ${test_name} ${test_options[*]}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Build the command with proper quoting +cmd="cd /workspace/build && ./bin/${test_name}" +for opt in "${test_options[@]}"; do + cmd="${cmd} $(printf '%q' "$opt")" +done + +docker exec "${CONTAINER_NAME}" bash -c "${cmd}" +exit_code=$? + +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +if [ $exit_code -eq 0 ]; then + echo "Test completed successfully" +else + echo "Test failed with exit code: ${exit_code}" +fi + +exit $exit_code From b8751e505d04cbb866bca769d408e9da8cb64c42 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 27 Jan 2026 00:57:42 +0530 Subject: [PATCH 26/42] feat: Add Interwave scheduler for aquant memory pipeline (#3540) * WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipline in aquant * test: add unit test for aquant memory pipeline * WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipline in aquant * fix: compilation error on gfx950 * chore: remove debug statements from the code * test: resolve merge conflict * test: remove non rcr unit tests from test suite --- .../gemm_aquant_quantgrouped.cpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 23 ++ .../run_gemm_quant_example.inc | 180 ++++++++++++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 223 +++++++++++++++- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 2 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 30 ++- ...gemm_quant_aquant_mem_decode_interwave.cpp | 41 +++ ...gemm_quant_aquant_mem_decode_intrawave.cpp | 41 +++ ...emm_quant_aquant_mem_prefill_interwave.cpp | 41 +++ .../test_gemm_quant_aquant_prefill.cpp | 6 +- .../test_gemm_quant_fixtures.hpp | 249 ++++++++++++++++++ 11 files changed, 829 insertions(+), 9 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ad1a4e0d10..e037be5a18 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantDecode; +using GemmConfig = GemmConfigQuantDecodeInterwave; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a95ca4862c..37117eaa0f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,6 +93,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigQuantDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -229,6 +250,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 912527c929..ed1709a9ae 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else { ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(1.0f)}(*aq_tensor_ptr); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) @@ -659,6 +659,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } + else if(init_method == 3) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + } + } + else if(init_method == 4) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } + else if(init_method == 5) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); + } + // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) + for(ck_tile::index_t row = 0; + row < static_cast(aq_tensor_ptr->get_length(0)); + ++row) + { + for(ck_tile::index_t col = 0; + col < static_cast(aq_tensor_ptr->get_length(1)); + ++col) + { + (*aq_tensor_ptr)(row, col) = static_cast(col + 1); + } + } + // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b52..9d19e902e5 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { CWarpTensor c_warp_tensor; + // for every column in AQ static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; @@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr } }; + template + struct BlockGemmImpl + { + static constexpr index_t KPerThread = GemmTraits::KPerThread; + static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + + static constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; + + static constexpr auto ALdsTileDistr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + static constexpr auto BLdsTileDistr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) + { + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + + auto a_lds_gemm_window = make_tile_window( + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); + auto b_lds_gemm_window = make_tile_window( + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); + + load_int4_tile( + a_warp_tile_, a_lds_gemm_window); + load_int4_tile( + b_warp_tile_, b_lds_gemm_window); + } + + // C += A * B with quantization support + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as corresponding " + "C block tensor data type!"); + constexpr auto warp_size = get_warp_size(); + + // Track which KRepeat chunk is currently loaded + index_t current_k_repeat_loaded = -1; + + // Restructured loop: M → N → QScale → KIterPerQScale + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Iterate over quantization groups + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + CWarpTensor c_warp_tensor; + + // Accumulate K iterations for this quantization group + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + // Map quantization indices to global K iteration + constexpr auto kIterGlobal = + kQScale * Traits::KIterPerQScale + kIterInQScale; + + // Map to KRepeat chunk and KInnerLoopIter offset + constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter; + constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter; + + // Prefetch new chunk if needed + if constexpr(kInnerIdx == 0) + { + if(current_k_repeat_loaded != kRepeatIdx) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + __builtin_amdgcn_sched_barrier(0); + + if constexpr(kRepeatIdx != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + current_k_repeat_loaded = kRepeatIdx; + } + } + + // Load A warp tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // Load B warp tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Synchronization barrier at the end of last iteration + if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 && + kIterInQScale == Traits::KIterPerQScale - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // Accumulate: first iteration initializes, rest accumulate + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Set priority for scheduling + if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // Apply quantization scale after accumulating all K iterations for this + // group + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + AQPickerCommon aq_picker( + aq_block_tensor); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + }; + public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { @@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr MakeCBlockTile(); } - template @@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr bool_constant a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + if constexpr(Scheduler == GemmPipelineScheduler::Interwave) + { + block_gemm_impl_.template LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } + else + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 650cd947f7..b87c12c14a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const OverrideADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 5749a8d3b2..30c4eb11f9 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -11,7 +11,24 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time - # AQuant tests - split into 6 files + # AQuant tests - split into 10 files + + # AQuant Memory Pipeline tests + add_gtest_executable(test_tile_gemm_quant_aquant_mem_prefill_interwave + test_gemm_quant_aquant_mem_prefill_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_prefill_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_intrawave + test_gemm_quant_aquant_mem_decode_intrawave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_intrawave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_interwave + test_gemm_quant_aquant_mem_decode_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr test_gemm_quant_aquant_base_rcr.cpp ) @@ -150,10 +167,21 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # Target to build only AQuant memory pipeline tests + add_custom_target(test_tile_gemm_aquant_mem_all) + add_dependencies(test_tile_gemm_aquant_mem_all + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave + ) + # Umbrella target to build all gemm quant tests add_custom_target(test_tile_gemm_quant_all) add_dependencies(test_tile_gemm_quant_all # AQuant tests + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave test_tile_gemm_quant_aquant_base_rcr test_tile_gemm_quant_aquant_base_rrr_crr test_tile_gemm_quant_aquant_base_ccr diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp new file mode 100644 index 0000000000..a7ab4120a1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp new file mode 100644 index 0000000000..483138d711 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Intrawave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeIntrawaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Intrawave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp new file mode 100644 index 0000000000..7e851d9bd3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Prefill Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemPrefillInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Prefill Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp index 133c11860a..911af678df 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPrefillTypes = ::testing::Types< // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935ef..9652dd449d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -69,6 +69,38 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefillIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigPrefillInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigDecodeIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + struct GemmConfigMxFp4 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; @@ -374,6 +406,223 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase +class TestCkTileGemmAQuantMem + : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::AQLayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + using typename Base::QuantGroupSize; + static constexpr auto QuantType = Base::QuantType; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + // AQuant-specific data generation + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = + ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{})); + const ck_tile::index_t stride_B = + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + const ck_tile::index_t stride_C = + ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{})); + // AQuant uses grouped quantization for A matrix + const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK); + // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant) + const ck_tile::index_t stride_AQ = + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{})); + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + // AQLayout is independently specified for each test case + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + // Initialize data with random values + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f}(a_m_k); + } + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f}(aq_m_aqk); + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + // Copy to device + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor temp = a_m_k; + ck_tile::permute_vectors_i4x4_b(temp); + a_m_k_dev_buf.ToDevice(temp.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } + // aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + if constexpr(Base::GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor aq_shuffle_host = + ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales) + nullptr, // bq_ptr (not used for AQuant) + 1, // k_batch + M, + N, + K, // M, N, K + AQK, // QK_A + 0, // QK_B (not used for AQuant) + stride_A, + stride_B, + stride_C, + stride_AQ, + 0 // strides + }; + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + // Run reference AQuant implementation + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N + << ", K=" << K; + if(!pass) + { + std::cout << "AQuantGrouped - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // AQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = CodegenGemmTraits::TransposeC; + using PipelineProblem = ck_tile::GemmAQuantPipelineProblem; + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrMem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c>>; + using Kernel = ck_tile::QuantGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for AQuant kernel"); + } + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + // BQuant-specific test fixture template class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase> From 8942a19d5efafa151e0f894599bc625117d7aa76 Mon Sep 17 00:00:00 2001 From: yinglu Date: Tue, 27 Jan 2026 03:38:45 +0800 Subject: [PATCH 27/42] ck: add CK_USE_GFX950 macro (#3636) --- CMakeLists.txt | 5 +++++ include/ck/config.h.in | 7 ------- .../device_grouped_conv_bwd_data_xdl_instance.hpp | 2 +- .../device_grouped_conv_fwd_xdl_merged_groups_instance.hpp | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f1bdf8689..356491d9c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,6 +259,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL) + message(STATUS "Enabling XDL FP8 gemms on gfx950") + add_definitions(-DCK_USE_GFX950) + set(CK_USE_GFX950 "ON") +endif() # new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA set(CK_TILE_USE_WMMA 0) diff --git a/include/ck/config.h.in b/include/ck/config.h.in index f5421e7d5e..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,9 +55,6 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif -#ifndef CK_ENABLE_TF32 -#define CK_ENABLE_TF32 "ON" -#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -88,10 +85,6 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif -#ifndef CK_ENABLE_TF32 -#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ -#endif - #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 745f8cbd32..970bcb0439 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -376,7 +376,7 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances = // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 18abcb1613..3b7ce0df3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -147,7 +147,7 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; From bd5fec81afdb6df7f4637128a3ba86dbfd6bcca1 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Mon, 26 Jan 2026 13:56:06 -0600 Subject: [PATCH 28/42] Removing [4,64,16] warp tile from Tile Engine (#3643) --- tile_engine/ops/gemm/gemm_validation_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_validation_utils.py b/tile_engine/ops/gemm/gemm_validation_utils.py index cae6123307..1af45f8e90 100644 --- a/tile_engine/ops/gemm/gemm_validation_utils.py +++ b/tile_engine/ops/gemm/gemm_validation_utils.py @@ -128,7 +128,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -136,7 +135,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], @@ -148,7 +146,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -156,7 +153,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], @@ -169,7 +165,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -177,7 +172,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [ From 2e49b6b2f79d5ab0fe2fca79812affd44de94db7 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Mon, 26 Jan 2026 21:57:09 +0100 Subject: [PATCH 29/42] Padding support for wave transfer (#3537) * Add padding support with transpose Also move check before writing storing is_src_valid during reading * Add/modify instances to use wave transfer for gemm universal Condition is changed so now the vectorsize of vmem reading and lds writing must be equal to 8 in order to use the wave transfer * Fix clang format * Modify example * Fix bwd data * Add restriction for wave transfer with padding and transpose Add test case which shows this limitation * Fix validity checks 8 bit types * Add validity check gemm_bias_add_reduce * Add validity check grouped gemm tile loop * Fix validity checks new flavours * Minor fixes * Fix clang format --- example/01_gemm/gemm_wmma_fp16_v3.cpp | 10 +-- ...ead_group_tensor_slice_transfer_global.hpp | 69 +++++++++++++--- ...ontraction_multiple_d_wmma_cshuffle_v3.hpp | 20 +++++ ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 20 +++++ ...e_batched_gemm_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ ...e_batched_gemm_wmma_cshuffle_v3_common.hpp | 20 +++++ ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 22 ++++++ .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 20 +++++ .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 20 +++++ ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 8 +- ..._multiple_d_wmma_cshuffle_tile_loop_v3.hpp | 23 ++++++ ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 23 +++++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 4 - .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 79 ++++++++++++++++--- ...wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 3 +- ...wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 5 +- ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 5 +- ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 2 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 +- .../test_gemm_universal_ut_cases_fp16.inc | 8 +- 23 files changed, 385 insertions(+), 50 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 5b10edd681..3b3b0fec16 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -19,22 +19,22 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, - PassThrough, PassThrough, PassThrough, GemmDefault, + PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, + 1, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 701c786c86..1c322fe4a7 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal // check if src element is valid const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - using vector_t = typename vector_type_maker::type::type; - dst_vector_type op_r_v; // Load data from memory in src_vector first - src_vector_container src_vector = - src_vector_container{grid_buf.template Get( - src_coord_.GetOffset(), true)}; + auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; + src_vector_container src_vector = src_vector_container{ + grid_buf.template Get(index, true)}; // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { @@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - dvgpr_.template SetAsType( - vgpr_data_idx_seq, - is_src_valid ? op_r_v.template AsType()[I0] : vector_t(0)); + src_dvgpr_.template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; + // OOB check + static_ford{}([&](auto ordered_src_access_idx) { + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + const bool is_src_valid = + oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + auto op_r_v = is_src_valid ? op_r : dst_vector_t(0); + dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); + }); + // make forward steps // forward step for each iteration just add 1 const auto dst_forward_steps = generate_tuple( @@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal dst_buf.template Set( dst_coord_.GetOffset(), true, - dvgpr_.template GetAsType(vgpr_data_idx_seq)); + dst_dvgpr_.template GetAsType(vgpr_data_idx_seq)); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); } + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto access_lengths_as_tuple = + container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); + + return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); + } + static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){}; using ThreadScratchData = StaticTensorTupleOfVectorBuffer; - ThreadScratchData dvgpr_; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + using OOBThreadScratch = StaticTensorTupleOfVectorBuffer; + + ThreadScratchData src_dvgpr_; + ThreadScratchData dst_dvgpr_; + OOBThreadScratch oob_thread_scratch_; SrcCoord src_coord_; DstCoord dst_coord_; const ElementwiseOperation element_op_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp index 47ef2e339d..b59357ffe9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + // check vector access static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 126d107725..ae247f4e31 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 227a8aedd9..593a908498 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index 59a820861c..fb1ca3127e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index e8e3b69cb5..85ca16b293 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index f0216c3f71..81f505b594 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemmWelford::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 317c4073df..28c9f2bddc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index e96ec58cba..c09befa717 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index e09c69d052..377f792979 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1(&arg)); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index bbf62d5fbe..dfdfd53725 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, AComputeType, BComputeType, - false, - false>; + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ @@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ - AComputeType, false, false + AComputeType, false, false, false, true using GridwiseGemmCTranspose = std::conditional_t placeholder_p_ds_grid{}; std::array stride_Ds; std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 39024d39e4..99a18e07fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK::selected_wmma .wave_size; + __host__ __device__ static constexpr bool AWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && + !IsBPreShuffled; + } + + __host__ __device__ static constexpr bool BWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + } + // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default with transpose #ifdef __gfx12__ - static constexpr bool IsAWaveTransferApplicable = - !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; + static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); - static constexpr bool IsBWaveTransferApplicable = - !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable(); static constexpr bool IsWaveTileInterleavedFitting = (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); @@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return de_grid_desc_mblock_mperblock_nblock_nperblock; } + // Conditions for Wave Transfer with transpose: + // - 16 bit type: K % 8 == 0 (4 subtiles of 8x8) + // - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16) + __host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K) + { + if constexpr(AWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + pass &= !(sizeof(ADataType_) == 1 && + !(M % (2 * ABlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + + __host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K) + { + if constexpr(BWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + pass &= !(sizeof(BDataType_) == 1 && + !(N % (2 * BBlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ static constexpr bool CheckValidity(const Argument& karg, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index d79fe9bfa3..d7b654a345 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index e284cbbb83..7d7966c47f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 6195d40f87..2f63199480 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index e51bec3dfb..b50e37cf0a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 66ba1e3830..4651068d86 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 8eccccf354..4dcbaccaa4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc index 25d95cda3d..01d7d5a5fd 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -125,7 +125,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM) TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -139,7 +139,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -153,7 +153,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -169,7 +169,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; From a213ce676bb6b72e177f73befa4d56b0ce60fbec Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 26 Jan 2026 13:44:36 -0800 Subject: [PATCH 30/42] Add python analysis scripts for Clang's time trace (#3644) This PR introduces a Python toolkit for analyzing Clang's `-ftime-trace` build performance data. This is the foundation for our systematic effort to reduce CK and CK-Tile build times (#3575). The toolkit provides fast parsing of trace JSON files into pandas DataFrames using orjson, with specialized functions for analyzing template instantiation costs and compilation phase breakdowns. It includes a core library (`trace_analysis/`), example scripts for quick analysis, a comprehensive README with usage documentation, and an interactive Jupyter notebook demonstration. Key features include memory-efficient DataFrame schemas with optimized dtypes, recursive hierarchical phase analysis, automatic metadata extraction (source file, compilation timing), and template instantiation filtering. The design supports both standalone scripts and interactive Jupyter notebook workflows. This single-file analysis capability lays the groundwork for future multi-file analysis across thousands of compilation units, enabling data-driven optimization and build time regression detection. --- script/analyze_build/README.md | 263 +++++++++++++ .../notebooks/file_analysis_example.ipynb | 247 ++++++++++++ script/analyze_build/requirements.txt | 18 + .../analyze_build/trace_analysis/__init__.py | 34 ++ .../trace_analysis/parse_file.py | 356 ++++++++++++++++++ .../trace_analysis/phase_breakdown.py | 354 +++++++++++++++++ .../trace_analysis/template_analysis.py | 80 ++++ .../trace_analysis/template_parser.py | 301 +++++++++++++++ 8 files changed, 1653 insertions(+) create mode 100644 script/analyze_build/README.md create mode 100644 script/analyze_build/notebooks/file_analysis_example.ipynb create mode 100644 script/analyze_build/requirements.txt create mode 100644 script/analyze_build/trace_analysis/__init__.py create mode 100644 script/analyze_build/trace_analysis/parse_file.py create mode 100644 script/analyze_build/trace_analysis/phase_breakdown.py create mode 100644 script/analyze_build/trace_analysis/template_analysis.py create mode 100644 script/analyze_build/trace_analysis/template_parser.py diff --git a/script/analyze_build/README.md b/script/analyze_build/README.md new file mode 100644 index 0000000000..7a88b98e77 --- /dev/null +++ b/script/analyze_build/README.md @@ -0,0 +1,263 @@ +# Build Trace Analysis + +Simple to use, fast python tools for analyzing Clang `-ftime-trace` build performance data. + +## Overview + +We're kicking off a systematic effort to dramatically reduce CK and CK-Tile build times, [#3575](https://github.com/ROCm/composable_kernel/issues/3575). A key part of this work is improving our C++ metaprogramming to reduce the burden on the compiler. + +In order to prioritize work and measure our progress, we need data on template instantiation. For single files, Clang's `-ftime-trace` build performance data is easy to analyze with the Perfetto UI. The problem we are solving here is how to analyze instantiation data across thousands of compilation units. + +The python code in this directory provides helper functions to quickly load JSON files into pandas DataFrames that can be used for analysis in Jupyter notebooks. + +## Directory Structure + +``` +script/analyze_build/ +├── trace_analysis/ # Core library +│ ├── __init__.py # Main exports +│ ├── parse_file.py # Fast parsing of JSON trace files +│ ├── template_analysis.py # Template instantiation analysis +│ ├── template_parser.py # Template name parsing utilities +│ └── phase_breakdown.py # Compilation phase breakdown +├── notebooks/ # Jupyter notebooks for analysis +│ └── file_analysis_example.ipynb # Template analysis example +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +## Python Requirements + +See `requirements.txt` for the complete list of dependencies: +* **pandas** - DataFrame manipulation and analysis +* **orjson** - Fast JSON parsing for trace files +* **plotly** - Interactive visualizations (sunburst, treemap) +* **nbformat** - Jupyter notebook format support +* **ipykernel** - Kernel for running notebooks in VSCode/Jupyter +* **kaleido** - Static image export from Plotly charts +* **jupyter** - Full Jupyter environment + +## Quick Start + +### Setup + +1. Create a virtual environment (recommended): +```bash +cd script/analyze_build +python3 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +``` + +2. Install dependencies: +```bash +pip install -r requirements.txt +``` + +3. Install VSCode extensions if you want to run notebooks in VSCode: + * Jupyter + * Data Wrangler (interact with Pandas DataFrames) + +### Analyzing a Single File + +Use the `parse_file` function to load a `-ftime-trace` JSON file into a Pandas DataFrame: + +```python +from trace_analysis import parse_file + +# Parse the trace file +df = parse_file('path/to/trace.json') + +# View basic info +print(f"Total events: {len(df)}") +print(df.columns) + +# Analyze duration statistics +print(df['dur'].describe()) +``` + +### Extracting Compilation Metadata + +Get high-level metadata about the compilation: + +```python +from trace_analysis import get_metadata + +# Extract metadata from trace file +metadata = get_metadata('trace.json') + +print(f"Source file: {metadata['source_file']}") +print(f"Compilation time: {metadata['total_wall_time_s']:.2f}s") +print(f"Started: {metadata['wall_start_datetime']}") +print(f"Ended: {metadata['wall_end_datetime']}") +``` + +The metadata includes: +- `source_file`: Main .cpp/.c file being compiled +- `time_granularity`: Time unit used ("microseconds") +- `beginning_of_time`: Epoch timestamp in microseconds +- `wall_start_time`: Wall clock start (microseconds since epoch) +- `wall_end_time`: Wall clock end (microseconds since epoch) +- `wall_start_datetime`: Human-readable start time +- `wall_end_datetime`: Human-readable end time +- `total_wall_time_us`: Total compilation time in microseconds +- `total_wall_time_s`: Total compilation time in seconds + +### Template Instantiation Analysis + +The module includes specialized functions for analyzing C++ template instantiation costs: + +```python +from trace_analysis import ( + parse_file, + get_template_instantiation_events, + get_phase_breakdown, +) + +df = parse_file('trace.json') + +# Get all template instantiation events with parsed template information +template_events = get_template_instantiation_events(df) + +# The returned DataFrame includes parsed columns: +# - namespace: Top-level namespace (e.g., 'std', 'ck') +# - template_name: Template name without parameters +# - full_qualified_name: Full namespace::template_name +# - param_count: Number of template parameters +# - is_ck_type: Boolean indicating CK library types +# - is_nested: Boolean indicating nested templates + +# Find slowest template instantiations +top_templates = template_events.nlargest(20, 'dur') +print(top_templates[['template_name', 'namespace', 'param_count', 'dur']]) + +# Analyze by namespace +namespace_summary = template_events.groupby('namespace').agg({ + 'dur': ['count', 'sum', 'mean'] +}) +print(namespace_summary) +``` + +### Compilation Phase Breakdown + +Analyze how compilation time is distributed across different phases: + +```python +from trace_analysis import get_phase_breakdown, PhaseBreakdown + +df = parse_file('trace.json') + +# Get hierarchical phase breakdown +breakdown = get_phase_breakdown(df) + +# Display in Jupyter (automatic rich HTML display) +display(breakdown) + +# Print text representation +print(breakdown) + +# Access the underlying DataFrame +print(breakdown.df) + +# Convert to plotly format for visualization +import plotly.express as px +data = breakdown.to_plotly() +fig = px.sunburst(**data) +fig.show() +``` + +The `PhaseBreakdown` class provides: +- Hierarchical breakdown of compilation phases +- Automatic calculation of "Other" residual time at each level +- Validation that children don't exceed parent durations +- Multiple output formats (text, DataFrame, Plotly) + +## DataFrame Schema + +The parsed DataFrame contains the following columns from the `-ftime-trace` format: + +- `name`: Event name (function, template instantiation, etc.) +- `ph`: Phase character ('X' for complete, 'B' for begin, 'E' for end, 'i' for instant) +- `ts`: Timestamp in microseconds +- `dur`: Duration in microseconds (for complete events) +- `pid`: Process ID +- `tid`: Thread ID +- `arg_*`: Flattened arguments from the event's `args` field + +### Template Event Columns + +When using `get_template_instantiation_events()`, additional parsed columns are included: + +- `namespace`: Top-level namespace extracted from the template name +- `template_name`: Template name without namespace or parameters +- `full_qualified_name`: Complete namespace::template_name +- `param_count`: Number of template parameters +- `is_ck_type`: Boolean flag for CK library types (namespace starts with 'ck') +- `is_nested`: Boolean flag indicating nested template instantiations + +## Use in Jupyter Notebooks + +The module is designed to work seamlessly in Jupyter notebooks. See `notebooks/file_analysis_example.ipynb` for a complete example workflow that demonstrates: + +- Loading and parsing trace files +- Extracting compilation metadata +- Analyzing phase breakdown with visualizations +- Template instantiation analysis with parsed columns +- Filtering and grouping by namespace +- Identifying CK-specific template costs + +To use in a notebook: + +```python +import sys +from pathlib import Path + +# Add trace_analysis to path +sys.path.insert(0, str(Path.cwd().parent)) + +from trace_analysis import ( + parse_file, + get_metadata, + get_template_instantiation_events, + get_phase_breakdown, +) + +# Load and analyze +df = parse_file('path/to/trace.json') +breakdown = get_phase_breakdown(df) +templates = get_template_instantiation_events(df) + +# Visualize +import plotly.express as px +fig = px.sunburst(**breakdown.to_plotly()) +fig.show() +``` + +## API Reference + +### Core Functions + +- `parse_file(filepath)`: Parse a `-ftime-trace` JSON file into a pandas DataFrame +- `get_metadata(filepath_or_df)`: Extract compilation metadata from trace file or DataFrame + +### Template Analysis + +- `get_template_instantiation_events(df)`: Filter to template instantiation events with parsed template information + +### Phase Breakdown + +- `get_phase_breakdown(df)`: Generate hierarchical compilation phase breakdown +- `PhaseBreakdown`: Class representing phase breakdown with multiple output formats + +## Contributing + +This is an experimental project for analyzing and improving C++ metaprogramming build times. Contributions are welcome! When adding new analysis functions: + +1. Add the function to the appropriate module in `trace_analysis/` +2. Export it in `__init__.py` +3. Update this README with usage examples +4. Consider adding a notebook example if the feature is substantial + +## License + +Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +SPDX-License-Identifier: MIT diff --git a/script/analyze_build/notebooks/file_analysis_example.ipynb b/script/analyze_build/notebooks/file_analysis_example.ipynb new file mode 100644 index 0000000000..e8d1ee3bcd --- /dev/null +++ b/script/analyze_build/notebooks/file_analysis_example.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Template Instantiation Analysis Example\n", + "\n", + "This notebook demonstrates how to use the template analysis functions to understand C++ template instantiation costs in Clang's `-ftime-trace` output.\n", + "\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add parent directory to path\n", + "sys.path.insert(0, str(Path.cwd().parent))\n", + "\n", + "from trace_analysis import (\n", + " parse_file,\n", + " get_template_instantiation_events,\n", + " get_phase_breakdown,\n", + " get_metadata,\n", + ")\n", + "\n", + "import pandas as pd\n", + "from datetime import datetime\n", + "import plotly.express as px\n", + "\n", + "\n", + "# Display settings\n", + "pd.set_option(\"display.max_rows\", 100)\n", + "pd.set_option(\"display.max_columns\", None)\n", + "pd.set_option(\"display.width\", None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Trace File" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load your trace file\n", + "trace_file = Path(\n", + " \"../../../build-trace/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeFiles/device_conv2d_fwd_instance.dir/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp.json\"\n", + ")\n", + "df = parse_file(trace_file)\n", + "\n", + "print(f\"Total events: {len(df):,}\")\n", + "starting_timestamp = datetime.fromtimestamp(df.attrs[\"beginningOfTime\"] / 1e6)\n", + "print(f\"Starting timestamp: {starting_timestamp.strftime('%Y-%m-%d:%H:%M:%S')}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_metadata(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compilation Overview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get phase breakdown and display it\n", + "breakdown = get_phase_breakdown(df)\n", + "print(breakdown)\n", + "display(breakdown)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract data for plotly charts (sunburst, tree-map, or icicle)\n", + "plotly_data = breakdown.to_plotly()\n", + "fig = px.sunburst(**plotly_data)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template Instantiation Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all template instantiation events (now with parsed columns!)\n", + "template_events = get_template_instantiation_events(df)\n", + "\n", + "print(f\"Total template instantiation events: {len(template_events):,}\")\n", + "print(f\"Total template time: {template_events['dur'].sum() / 1000:.1f} ms\")\n", + "display(template_events)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Examine Parsed Columns\n", + "\n", + "The `get_template_instantiation_events()` function automatically parses the `arg_detail` column into structured fields:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the new parsed columns\n", + "print(\"Parsed columns available:\")\n", + "print(\"- namespace: Top-level namespace (e.g., 'std', 'ck')\")\n", + "print(\"- template_name: Template name without parameters\")\n", + "print(\"- full_qualified_name: Full namespace::template_name\")\n", + "print(\"- param_count: Number of template parameters\")\n", + "print(\"- is_ck_type: Boolean indicating CK library types\")\n", + "print(\"- is_nested: Boolean indicating nested templates\")\n", + "print()\n", + "\n", + "# Display sample of parsed data\n", + "template_events[\n", + " [\"namespace\", \"template_name\", \"param_count\", \"is_ck_type\", \"is_nested\", \"dur\"]\n", + "].head(20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analysis by Namespace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Group by namespace to see where time is spent\n", + "namespace_summary = (\n", + " template_events.groupby(\"namespace\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"], \"param_count\": \"mean\"})\n", + " .round(2)\n", + ")\n", + "\n", + "namespace_summary.columns = [\"count\", \"total_dur\", \"avg_dur\", \"avg_params\"]\n", + "namespace_summary[\"total_ms\"] = namespace_summary[\"total_dur\"] / 1000\n", + "namespace_summary = namespace_summary.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTemplate Instantiation Time by Namespace:\")\n", + "display(namespace_summary)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CK Library Templates Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter to CK types only\n", + "ck_templates = template_events[template_events[\"is_ck_type\"]].copy()\n", + "\n", + "print(f\"CK template instantiations: {len(ck_templates):,}\")\n", + "print(f\"CK template time: {ck_templates['dur'].sum() / 1000:.1f} ms\")\n", + "print(\n", + " f\"Percentage of total template time: {100 * ck_templates['dur'].sum() / template_events['dur'].sum():.1f}%\"\n", + ")\n", + "print()\n", + "\n", + "# Top CK templates by time\n", + "ck_by_name = (\n", + " ck_templates.groupby(\"template_name\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"]})\n", + " .round(2)\n", + ")\n", + "ck_by_name.columns = [\"count\", \"total_dur\", \"avg_dur\"]\n", + "ck_by_name[\"total_ms\"] = ck_by_name[\"total_dur\"] / 1000\n", + "ck_by_name = ck_by_name.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTop CK Templates by Total Time:\")\n", + "display(ck_by_name.head(20))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/script/analyze_build/requirements.txt b/script/analyze_build/requirements.txt new file mode 100644 index 0000000000..fd99fdba09 --- /dev/null +++ b/script/analyze_build/requirements.txt @@ -0,0 +1,18 @@ +# Build Trace Analysis - Python Dependencies + +# Core data processing +pandas>=2.0.0 +orjson>=3.9.0 + +# Jupyter notebook support +nbformat>=4.2.0 +ipykernel>=6.0.0 + +# Interactive visualizations +plotly>=5.0.0 + +# Static image export from Plotly +kaleido>=0.2.0 + +# Full Jupyter environment (if not using VSCode) +jupyter>=1.0.0 diff --git a/script/analyze_build/trace_analysis/__init__.py b/script/analyze_build/trace_analysis/__init__.py new file mode 100644 index 0000000000..70db321083 --- /dev/null +++ b/script/analyze_build/trace_analysis/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build Trace Analysis - Core library for analyzing Clang -ftime-trace data. + +This package provides tools to parse and analyze Clang's -ftime-trace JSON output +for build performance analysis. +""" + +from .parse_file import ( + parse_file, + get_metadata, +) + +from .template_analysis import ( + get_template_instantiation_events, +) + +from .phase_breakdown import ( + get_phase_breakdown, + PhaseBreakdown, +) + +__all__ = [ + # Core parsing and filtering + "parse_file", + "get_metadata", + # Template analysis + "get_template_instantiation_events", + # Phase breakdown + "get_phase_breakdown", + "PhaseBreakdown", +] diff --git a/script/analyze_build/trace_analysis/parse_file.py b/script/analyze_build/trace_analysis/parse_file.py new file mode 100644 index 0000000000..24d71e4eb8 --- /dev/null +++ b/script/analyze_build/trace_analysis/parse_file.py @@ -0,0 +1,356 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Parse a single Clang -ftime-trace JSON file into a Pandas DataFrame. + +This module provides fast parsing of Clang's -ftime-trace output using orjson +for performance. The JSON file is typically a single-line array of trace events. +""" + +import orjson +import pandas as pd +from pathlib import Path +from typing import Union, Optional +from datetime import datetime +from dataclasses import dataclass + + +# Expected schema for trace event DataFrames with optimized dtypes +# This enforces strict column validation and memory-efficient types +# The memory usage is dominated by arg detail, but we optimize each series. +TRACE_EVENT_DTYPES = { + "pid": "int32", # Process ID (max observed: ~2.3M, fits in int32) + "tid": "int32", # Thread ID (max observed: ~2.3M, fits in int32) + "ts": "int64", # Timestamp in microseconds (requires int64 for epoch times) + "cat": "category", # Category (low cardinality, use categorical) + "ph": "category", # Phase type (very low cardinality: X, B, E, i, etc.) + "id": "int64", # Event ID + "name": "category", # Event name (medium cardinality, use categorical) + "dur": "int64", # Duration in microseconds (max 10 days = 864B μs, requires int64) + "arg_detail": "string", # Detail string (high cardinality, keep as string) + "arg_count": "int64", # Argument count + "arg_avg ms": "int64", # Average milliseconds + "arg_name": "category", # Argument name (medium cardinality, use categorical) +} + + +@dataclass +class FileMetadata: + """ + Processed metadata with computed fields for compilation analysis. + + This extends the raw metadata with derived values like formatted timestamps + and converted time units for convenience. + + Attributes: + source_file: Main .cpp/.c file being compiled + time_granularity: Time unit used in trace (always "microseconds" for Clang) + beginning_of_time: Epoch timestamp in microseconds from JSON root + execute_compiler_ts: Timestamp of ExecuteCompiler event (microseconds) + execute_compiler_dur: Duration of ExecuteCompiler event (microseconds) + total_wall_time_us: Total compilation time in microseconds (same as execute_compiler_dur) + total_wall_time_s: Total compilation time in seconds (computed from microseconds) + wall_start_time: Wall clock start time in microseconds since epoch (computed) + wall_end_time: Wall clock end time in microseconds since epoch (computed) + wall_start_datetime: Human-readable start time string (formatted) + wall_end_datetime: Human-readable end time string (formatted) + """ + + source_file: Optional[str] = None + time_granularity: str = "microseconds" + beginning_of_time: Optional[int] = None + execute_compiler_ts: Optional[int] = None + execute_compiler_dur: Optional[int] = None + total_wall_time_us: Optional[int] = None + total_wall_time_s: Optional[float] = None + wall_start_time: Optional[int] = None + wall_end_time: Optional[int] = None + wall_start_datetime: Optional[str] = None + wall_end_datetime: Optional[str] = None + + def __repr__(self): + # auto-generate pretty lines + fields = "\n".join( + f" {name} = {value!r}" for name, value in self.__dict__.items() + ) + return f"{self.__class__.__name__}(\n{fields}\n)" + + +def parse_file(filepath: Union[str, Path]) -> pd.DataFrame: + """ + Parse a Clang -ftime-trace JSON file into a Pandas DataFrame. + + The -ftime-trace format is a JSON array of trace events. Each event contains + fields like name, phase (ph), timestamp (ts), duration (dur), process/thread IDs, + and optional arguments (args). + + The beginningOfTime value from the JSON structure is automatically extracted + and stored in df.attrs['beginningOfTime']. Use get_metadata(df) to get + processed metadata with event-derived fields and computed values. + + Args: + filepath: Path to the -ftime-trace JSON file + + Returns: + DataFrame with columns for each event field. Nested 'args' are flattened + with an 'arg_' prefix. The beginningOfTime value is stored in + df.attrs['beginningOfTime']. + + Raises: + FileNotFoundError: If the file doesn't exist + ValueError: If the JSON is invalid or empty + + Examples: + >>> df = parse_file('build/trace.json') + >>> df[['name', 'dur']].head() + >>> + >>> # Access processed metadata + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> + >>> # Access beginningOfTime directly if needed + >>> beginning = df.attrs.get('beginningOfTime') + >>> print(f"Beginning of time: {beginning}") + """ + filepath = Path(filepath) + + if not filepath.exists(): + raise FileNotFoundError(f"Trace file not found: {filepath}") + + # Read and parse JSON using orjson for speed + with open(filepath, "rb") as f: + data = orjson.loads(f.read()) + + if not data: + raise ValueError(f"Empty trace data in file: {filepath}") + + # Handle both formats: direct array or {"traceEvents": [...]} + if isinstance(data, dict): + if "traceEvents" in data: + events = data["traceEvents"] + else: + raise ValueError( + f"Expected 'traceEvents' key in JSON object, got keys: {list(data.keys())}" + ) + elif isinstance(data, list): + events = data + else: + raise ValueError(f"Expected JSON array or object, got {type(data).__name__}") + + # Convert to DataFrame + df = pd.DataFrame(events) + + if df.empty: + raise ValueError(f"No trace events found in file: {filepath}") + + # Flatten 'args' column if it exists + if "args" in df.columns: + df = _flatten_args(df) + + # Validate schema: check for missing columns + expected_columns = set(TRACE_EVENT_DTYPES.keys()) + actual_columns = set(df.columns) + + missing_columns = expected_columns - actual_columns + if missing_columns: + raise ValueError( + f"Missing expected columns in trace data: {sorted(missing_columns)}" + ) + + # Validate schema: check for unexpected columns + unexpected_columns = actual_columns - expected_columns + if unexpected_columns: + raise ValueError( + f"Unexpected columns found in trace data: {sorted(unexpected_columns)}" + ) + + # Apply optimized dtypes with strict type enforcement + for col, dtype in TRACE_EVENT_DTYPES.items(): + if dtype in ("int64", "int32"): + # Fill missing values with 0 for integer columns, then convert to specified int type + df[col] = df[col].fillna(0).astype(dtype) + elif dtype == "category": + # Convert to categorical for memory efficiency with repeated values + df[col] = df[col].astype("category") + elif dtype == "string": + # Convert to pandas string dtype for memory efficiency + df[col] = df[col].astype("string") + else: + raise ValueError(f"Unsupported dtype '{dtype}' for column '{col}'") + + # Extract and store beginningOfTime in DataFrame attributes + df.attrs["beginningOfTime"] = ( + data.get("beginningOfTime") if isinstance(data, dict) else None + ) + + return df + + +def _flatten_args(df: pd.DataFrame) -> pd.DataFrame: + """ + Flatten the 'args' column into separate columns with 'arg_' prefix. + + The 'args' field in trace events contains additional metadata as a dictionary. + This function extracts those key-value pairs into separate columns. + + Args: + df: DataFrame with an 'args' column containing dictionaries + + Returns: + DataFrame with flattened args columns and original 'args' column removed + """ + # Extract args into separate DataFrame + args_data = [] + for idx, row in df.iterrows(): + args = row.get("args", {}) + if isinstance(args, dict): + args_data.append(args) + else: + args_data.append({}) + + if args_data: + args_df = pd.DataFrame(args_data) + # Prefix all args columns with 'arg_' + args_df.columns = [f"arg_{col}" for col in args_df.columns] + + # Drop original args column and concatenate flattened args + df = df.drop(columns=["args"]) + df = pd.concat([df, args_df], axis=1) + + return df + + +def _normalize_source_path(file_path: str) -> str: + """ + Normalize a source file path to be relative to composable_kernel if present. + + If 'composable_kernel' appears in the path, returns the path starting from + 'composable_kernel/'. Otherwise, returns the original path unchanged. + + Args: + file_path: Full filesystem path to a source file + + Returns: + Normalized path starting from composable_kernel, or original path if + composable_kernel is not found + + Examples: + >>> _normalize_source_path('/home/user/composable_kernel/include/ck/tensor.hpp') + 'composable_kernel/include/ck/tensor.hpp' + >>> _normalize_source_path('/usr/include/vector') + '/usr/include/vector' + """ + path = Path(file_path) + parts = path.parts + + # Find the last occurrence of 'composable_kernel' in the path + for i in range(len(parts) - 1, -1, -1): + if parts[i] == "composable_kernel": + # Return path from composable_kernel onwards + return str(Path(*parts[i:])) + + # If composable_kernel not found, return original path + return file_path + + +def get_metadata(df: pd.DataFrame) -> FileMetadata: + """ + Extract and process compilation metadata from a DataFrame. + + This function processes events from the DataFrame to extract compilation + information, then computes derived fields like formatted timestamps and + converted time units. + + Args: + df: DataFrame returned by parse_file() with beginningOfTime in its .attrs + + Returns: + FileMetadata instance with both raw and computed fields: + - source_file: Main .cpp/.c file being compiled (from events) + - time_granularity: Time unit used in trace ("microseconds") + - beginning_of_time: Epoch timestamp in microseconds from JSON root + - execute_compiler_ts: Timestamp of ExecuteCompiler event (from events) + - execute_compiler_dur: Duration of ExecuteCompiler event (from events) + - total_wall_time_us: Total compilation time in microseconds + - total_wall_time_s: Total compilation time in seconds (computed) + - wall_start_time: Wall clock start time (computed) + - wall_end_time: Wall clock end time (computed) + - wall_start_datetime: Human-readable start time (formatted) + - wall_end_datetime: Human-readable end time (formatted) + + Examples: + >>> df = parse_file('trace.json') + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> print(f"Started: {metadata.wall_start_datetime}") + """ + # Extract beginningOfTime from DataFrame attributes + beginning_of_time = None + if hasattr(df, "attrs"): + beginning_of_time = df.attrs.get("beginningOfTime") + + # Initialize metadata with beginningOfTime from JSON structure + metadata = FileMetadata(beginning_of_time=beginning_of_time) + + # Process events to extract ExecuteCompiler timing information + if "name" in df.columns: + execute_compiler = df[df["name"] == "ExecuteCompiler"] + if not execute_compiler.empty: + # Get the first ExecuteCompiler event + event = execute_compiler.iloc[0] + if "ts" in event: + metadata.execute_compiler_ts = event["ts"] + if "dur" in event: + metadata.execute_compiler_dur = event["dur"] + + # Process events to find the main source file being compiled + if "name" in df.columns and "arg_detail" in df.columns: + # Look for ParseDeclarationOrFunctionDefinition events with .cpp or .c files + source_extensions = (".cpp", ".cc", ".cxx", ".c") + parse_events = df[df["name"] == "ParseDeclarationOrFunctionDefinition"] + + for _, event in parse_events.iterrows(): + detail = event.get("arg_detail", "") + if detail: + # Extract file path (may include line:column info) + file_path = str(detail).split(":")[0] + + # Check if it's a source file (not a header) + if any(file_path.endswith(ext) for ext in source_extensions): + metadata.source_file = _normalize_source_path(file_path) + break + + # Compute derived fields + if metadata.execute_compiler_dur is not None: + metadata.total_wall_time_us = metadata.execute_compiler_dur + metadata.total_wall_time_s = metadata.execute_compiler_dur / 1_000_000.0 + + # Calculate wall clock times if we have the necessary data + if ( + metadata.beginning_of_time is not None + and metadata.execute_compiler_ts is not None + and metadata.execute_compiler_dur is not None + ): + metadata.wall_start_time = ( + metadata.beginning_of_time + metadata.execute_compiler_ts + ) + metadata.wall_end_time = ( + metadata.wall_start_time + metadata.execute_compiler_dur + ) + + # Convert to human-readable datetime strings + try: + start_dt = datetime.fromtimestamp(metadata.wall_start_time / 1_000_000.0) + end_dt = datetime.fromtimestamp(metadata.wall_end_time / 1_000_000.0) + metadata.wall_start_datetime = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[ + :-3 + ] + metadata.wall_end_datetime = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + except (OSError, ValueError): + # Handle invalid timestamps gracefully + pass + + return metadata diff --git a/script/analyze_build/trace_analysis/phase_breakdown.py b/script/analyze_build/trace_analysis/phase_breakdown.py new file mode 100644 index 0000000000..773ba06622 --- /dev/null +++ b/script/analyze_build/trace_analysis/phase_breakdown.py @@ -0,0 +1,354 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Phase breakdown analysis for Clang -ftime-trace data. + +This module provides hierarchical breakdown of compilation phases using +the pre-aggregated "Total" events from Clang's -ftime-trace output. + +The data is returned as a PhaseBreakdown object with rich display and +analysis capabilities optimized for Jupyter notebooks. +""" + +import pandas as pd +from collections import namedtuple +from typing import Optional + + +# Lightweight namedtuple for iteration +Phase = namedtuple("Phase", ["name", "depth", "duration", "duration_ms", "percentage"]) + + +class PhaseBreakdown: + """ + Wrapper for compilation phase breakdown with notebook-friendly API. + + Provides hierarchical view of compilation phases from Clang -ftime-trace, + with rich display, filtering, and visualization capabilities. + + Examples: + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter + >>> breakdown + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend + >>> + >>> # Get metrics + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Top N analysis + >>> breakdown.top(10) + >>> breakdown.frontend.top(5) + >>> + >>> # Visualization + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Iteration + >>> for phase in breakdown: + >>> print(f"{phase.name}: {phase.duration_ms:.1f}ms") + """ + + def __init__(self, df: pd.DataFrame): + """ + Initialize from phase breakdown DataFrame. + + Args: + df: DataFrame with columns name, parent, depth, duration + """ + if df.empty: + self._df = pd.DataFrame(columns=["name", "parent", "depth", "duration"]) + self._total_time = 0 + else: + self._df = df + self._total_time = self._get_total_time() + + def __repr__(self) -> str: + """Simple text representation for console.""" + if self._df.empty: + return "PhaseBreakdown(empty)" + n_phases = len(self._df) + return f"PhaseBreakdown({n_phases} phases, {self._total_time:.1f}ms total)" + + def _repr_html_(self) -> str: + """Rich HTML representation for Jupyter notebooks.""" + if self._df.empty: + return "
PhaseBreakdown(empty)
" + return self.to_dataframe()._repr_html_() + + @property + def df(self) -> pd.DataFrame: + """ + Access underlying DataFrame. + + Returns: + DataFrame with columns name, parent, depth, duration + """ + return self._df + + def to_dataframe(self, show_percentages: bool = True) -> pd.DataFrame: + """ + Format as DataFrame for display. + + Creates a nicely formatted DataFrame suitable for Jupyter notebook display. + + Args: + show_percentages: Include percentage of total time + + Returns: + DataFrame with formatted columns + """ + return self._format_dataframe(show_percentages) + + def to_plotly(self) -> dict: + """ + Convert to plotly hierarchical visualization format. + + Returns a dictionary with data_frame, values, and path that can be directly + used with plotly.express sunburst, treemap, or icicle charts. + + Returns: + Dictionary with keys: data_frame, values, path, branchvalues + + Example: + >>> data = breakdown.to_plotly() + >>> import plotly.express as px + >>> + >>> # Create sunburst chart + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Create treemap chart + >>> fig = px.treemap(**data) + >>> fig.show() + >>> + >>> # Create icicle chart + >>> fig = px.icicle(**data) + >>> fig.show() + """ + return self._build_plotly_data() + + # Internal helper methods + + def _get_total_time(self) -> int: + """Get total time from root ExecuteCompiler event.""" + root = self._df[self._df["depth"] == 0] + if root.empty: + return 0 + return int(root.iloc[0]["duration"]) + + def _format_dataframe(self, show_percentages: bool) -> pd.DataFrame: + """Format phase breakdown as DataFrame.""" + if self._df.empty: + return pd.DataFrame() + + display_rows = [] + for _, row in self._df.iterrows(): + duration_ms = row["duration"] / 1000.0 + display_row = { + "Name": row["name"], + "Parent": row["parent"] if row["parent"] else "(root)", + "Depth": row["depth"], + "Duration (ms)": duration_ms, + } + if show_percentages and self._total_time > 0: + pct = row["duration"] / self._total_time * 100 + display_row["% of Total"] = pct + display_rows.append(display_row) + + display_df = pd.DataFrame(display_rows) + + if show_percentages: + display_df["% of Total"] = display_df["% of Total"].round(1) + + return display_df + + def _build_plotly_data(self) -> dict: + """Convert to plotly hierarchical visualization format.""" + return { + "data_frame": self._df, + "names": "name", + "parents": "parent", + "values": "duration", + "branchvalues": "total", + } + + +# Hierarchical phase specification +# There are over 100 totals in the JSON file, but a lot of them overlap. +# If the children total more than their parent, we will throw a ValueError. +# +# The hierarchy is specified as a nested dictionary where: +# - Keys are phase names (matching "Total " events in the trace) +# - Values are dictionaries of child phases (or empty dict {} for leaf nodes) +# - Empty string "" as a key means "calculate Other as residual" +# +# This structure supports arbitrary nesting depth. +PHASE_HIERARCHY = { + "ExecuteCompiler": { + "Frontend": { + "InstantiateFunction": {}, + }, + "Backend": { + "Optimizer": {}, + "CodeGenPasses": {}, + }, + } +} + + +def get_phase_breakdown(df: pd.DataFrame) -> PhaseBreakdown: + """ + Get hierarchical breakdown of compilation phases. + + Returns a PhaseBreakdown object with rich display and analysis methods, + using the pre-aggregated "Total" events from Clang's -ftime-trace output + for accurate statistics. + + All durations are in microseconds. + + The hierarchy is defined by the PHASE_HIERARCHY constant and supports + arbitrary nesting depth. The tree is traversed recursively to build + the phase breakdown. + + Args: + df: DataFrame from parse_file() + + Returns: + PhaseBreakdown object with rich display and analysis methods + + Raises: + ValueError: If required Total events are missing or if calculated + "Other" values are negative (indicating data inconsistency) + + Examples: + >>> df = parse_file('trace.json') + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter (automatic) + >>> breakdown + >>> + >>> # Get total compilation time + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend.top(5) + >>> + >>> # Visualize + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + """ + if "name" not in df.columns or "dur" not in df.columns: + raise ValueError("DataFrame missing required 'name' or 'dur' columns") + + # Pre-filter to Total events for efficient lookup + total_events = df[df["name"].str.startswith("Total ", na=False)].copy() + total_events["phase"] = total_events["name"].str.removeprefix("Total ") + + def get_duration(phase_name: str) -> Optional[int]: + """Get duration in microseconds from a Total event.""" + matches = total_events[total_events["phase"] == phase_name] + if matches.empty: + return None + return int(matches.iloc[0]["dur"]) + + def process_node( + node_name: str, + parent_name: str, + depth: int, + children_spec: dict, + ) -> list[dict]: + """ + Recursively process a node and its children in the phase hierarchy. + + Args: + node_name: Name of the current phase node + parent_name: Name of the parent phase (empty string for root) + depth: Current depth in the tree (0 for root) + children_spec: Dictionary of child phases to process + + Returns: + List of row dictionaries for this node and all descendants + + Raises: + ValueError: If phase not found or children exceed parent duration + """ + # Get duration for this node + node_duration = get_duration(node_name) + if node_duration is None: + raise ValueError(f"No Total {node_name} event found in trace") + + # Add current node + rows = [ + { + "name": node_name, + "parent": parent_name, + "depth": depth, + "duration": node_duration, + } + ] + + if not children_spec: + return rows + + # Process all children recursively + children_total = 0 + for child_name, grandchildren_spec in children_spec.items(): + if child_name == "": + # Empty string means "Other" - skip for now, calculate as residual + continue + + # Recursively process this child and its descendants + child_rows = process_node( + child_name, node_name, depth + 1, grandchildren_spec + ) + rows.extend(child_rows) + + # Track total duration of direct children only (not grandchildren) + children_total += child_rows[0]["duration"] + + # Calculate and add "Other" if there's unaccounted time + other_duration = node_duration - children_total + if other_duration < 0: + raise ValueError( + f"{node_name} children total ({children_total}) " + f"exceeds parent total ({node_duration})" + ) + + if other_duration > 0: + rows.append( + { + "name": f"{node_name}_Other", + "parent": node_name, + "depth": depth + 1, + "duration": other_duration, + } + ) + + return rows + + # Start recursive traversal from root + root_name = "ExecuteCompiler" + if root_name not in PHASE_HIERARCHY: + raise ValueError(f"Root phase '{root_name}' not found in PHASE_HIERARCHY") + + all_rows = process_node( + root_name, + "", # Root has no parent + 0, # Root is at depth 0 + PHASE_HIERARCHY[root_name], + ) + + breakdown_df = pd.DataFrame(all_rows) + return PhaseBreakdown(breakdown_df) diff --git a/script/analyze_build/trace_analysis/template_analysis.py b/script/analyze_build/trace_analysis/template_analysis.py new file mode 100644 index 0000000000..ef483f6f53 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_analysis.py @@ -0,0 +1,80 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template instantiation analysis for Clang -ftime-trace data. + +This module provides specialized functions for analyzing C++ template +instantiation costs from Clang's -ftime-trace output. +""" + +import pandas as pd +from .template_parser import parse_template_detail + + +def get_template_instantiation_events(df: pd.DataFrame) -> pd.DataFrame: + """ + Filter to template instantiation events and parse arg_detail into structured columns. + + Returns events for: + - InstantiateFunction: Function template instantiations + - InstantiateClass: Class template instantiations + + The returned DataFrame includes parsed columns from arg_detail: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Args: + df: DataFrame from parse_file() + + Returns: + Filtered DataFrame containing template instantiation events with parsed columns + + Example: + >>> df = parse_file('trace.json') + >>> templates = get_template_instantiation_events(df) + >>> templates.sort_values('dur', ascending=False).head(10) + >>> # Filter to CK types only + >>> ck_templates = templates[templates['is_ck_type']] + >>> # Group by template name + >>> templates.groupby('template_name')['dur'].sum() + """ + # Filter to template instantiation events + filtered_df = ( + df[ + df["name"].isin( + [ + "InstantiateClass", + "InstantiateFunction", + ] + ) + ] + .drop( + columns=[ + "arg_avg ms", + "arg_count", + "arg_name", + "cat", + "id", + "ph", + "pid", + "tid", + ] + ) + .reset_index(drop=True) + ) + + # Parse arg_detail into structured columns + parsed_data = filtered_df["arg_detail"].apply(parse_template_detail) + + # Convert list of dicts to DataFrame and join with original + parsed_df = pd.DataFrame(parsed_data.tolist()) + + # Combine with original data + result_df = pd.concat([filtered_df, parsed_df], axis=1) + + return result_df diff --git a/script/analyze_build/trace_analysis/template_parser.py b/script/analyze_build/trace_analysis/template_parser.py new file mode 100644 index 0000000000..2551465bd4 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_parser.py @@ -0,0 +1,301 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template detail string parser for C++ template instantiations. + +This module provides functions to parse the arg_detail strings from +Clang's -ftime-trace output into structured components. +""" + +import re +from typing import Dict + + +def parse_template_detail(detail_str: str) -> Dict[str, any]: + """ + Parse a template detail string into structured components. + + Args: + detail_str: The arg_detail string from -ftime-trace + + Returns: + Dictionary with parsed fields: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Example: + >>> parse_template_detail('std::basic_string') + { + 'namespace': 'std', + 'template_name': 'basic_string', + 'full_qualified_name': 'std::basic_string', + 'param_count': 1, + 'is_ck_type': False, + 'is_nested': False + } + """ + # Handle empty or invalid strings + if not detail_str or not isinstance(detail_str, str): + return _empty_result() + + # Remove surrounding quotes if present + detail_str = detail_str.strip('"') + + # Extract components + namespace = extract_namespace(detail_str) + template_name = extract_template_name(detail_str) + full_qualified_name = extract_full_qualified_name(detail_str) + param_count = count_template_params(detail_str) + is_ck = is_ck_template(detail_str) + is_nested = is_nested_template(detail_str) + + return { + "namespace": namespace, + "template_name": template_name, + "full_qualified_name": full_qualified_name, + "param_count": param_count, + "is_ck_type": is_ck, + "is_nested": is_nested, + } + + +def extract_namespace(detail_str: str) -> str: + """ + Extract the top-level namespace from a template detail string. + + Args: + detail_str: The template detail string + + Returns: + The top-level namespace, or empty string if none found + + Example: + >>> extract_namespace('std::basic_string') + 'std' + >>> extract_namespace('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find first :: separator + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)::", detail_str) + if match: + return match.group(1) + + # No namespace found - check if it's a simple type + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_template_name(detail_str: str) -> str: + """ + Extract the template name without namespace or parameters. + + Args: + detail_str: The template detail string + + Returns: + The template name without namespace or parameters + + Example: + >>> extract_template_name('std::basic_string') + 'basic_string' + >>> extract_template_name('ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<...>') + 'GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the last component before < or end of string + # This handles nested namespaces like ck::tensor_operation::device::DeviceConv2d + match = re.search(r"::([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + # No :: found, try to get name before < + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_full_qualified_name(detail_str: str) -> str: + """ + Extract the full qualified name (namespace::...::template_name). + + Args: + detail_str: The template detail string + + Returns: + The full qualified name without template parameters + + Example: + >>> extract_full_qualified_name('std::basic_string') + 'std::basic_string' + >>> extract_full_qualified_name('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck::tensor_operation::device::DeviceConv2d' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Match everything up to the first < or end of string + match = re.match(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def count_template_params(detail_str: str) -> int: + """ + Count the number of top-level template parameters. + + This counts commas at the top level of template brackets, + not commas inside nested templates. + + Args: + detail_str: The template detail string + + Returns: + Number of template parameters, or 0 if not a template + + Example: + >>> count_template_params('std::basic_string') + 1 + >>> count_template_params('std::tuple') + 3 + """ + if not detail_str or "<" not in detail_str: + return 0 + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return 0 + + # Track bracket depth to only count top-level commas + depth = 0 + param_count = 1 # Start with 1 (if there's a <, there's at least one param) + in_template = False + + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + in_template = True + elif char == ">": + depth -= 1 + if depth == 0: + # We've closed the outermost template + break + elif char == "," and depth == 1: + # Top-level comma + param_count += 1 + + return param_count if in_template else 0 + + +def is_ck_template(detail_str: str) -> bool: + """ + Check if this is a CK library template. + + Args: + detail_str: The template detail string + + Returns: + True if this is a CK library type, False otherwise + + Example: + >>> is_ck_template('ck::tensor_operation::device::DeviceConv2d<...>') + True + >>> is_ck_template('std::basic_string') + False + """ + if not detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Check if it starts with ck:: or contains ::ck:: + return detail_str.startswith("ck::") or "::ck::" in detail_str + + +def is_nested_template(detail_str: str) -> bool: + """ + Check if this template contains nested template instantiations. + + Args: + detail_str: The template detail string + + Returns: + True if contains nested templates, False otherwise + + Example: + >>> is_nested_template('std::vector') + False + >>> is_nested_template('std::vector') + True + """ + if not detail_str or "<" not in detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return False + + # Look for nested < after the first one + depth = 0 + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + if depth > 1: + # Found a nested template + return True + elif char == ">": + depth -= 1 + if depth == 0: + break + + return False + + +def _empty_result() -> Dict[str, any]: + """Return an empty result dictionary with default values.""" + return { + "namespace": "", + "template_name": "", + "full_qualified_name": "", + "param_count": 0, + "is_ck_type": False, + "is_nested": False, + } From 42a731b791e72d4ea5f270be905e6fa1eb524626 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 12:28:59 -0500 Subject: [PATCH 31/42] Updating failure patterns to be more reliable and adding tests to verify they are caught in the logs --- Jenkinsfile | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index f3a597e404..712602e532 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /docker login failed/, description: "Docker login failed"], + [pattern: /.docker login failed./, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /GPU not found/, description: "GPU not found"], + [pattern: /.GPU not found./, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] @@ -1290,6 +1290,13 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" + // Todo: Remove test examples + echo "GPU not found" + echo "Testing GPU not found" + echo "GPU not found Testing" + echo "docker login failed" + echo "Testing docker login failed" + echo "docker login failed Testing" } } } From 786965b95ed049e7ba2f0e6f00875a2634db90f9 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 12:47:27 -0500 Subject: [PATCH 32/42] Fixing Jenkinsfile too large error --- Jenkinsfile | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 712602e532..cd7678df1a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -34,6 +34,18 @@ def checkForPattern(pattern, log) { return [found: false, matchedLine: "", context: ""] } +def testLog() { + // Todo: Remove test examples + sh """ + echo "GPU not found" + echo "Testing GPU not found" + echo "GPU not found Testing" + echo "docker login failed" + echo "Testing docker login failed" + echo "docker login failed Testing" + """ +} + // Scan the build logs for failures and send notifications. def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. @@ -1290,13 +1302,7 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" - // Todo: Remove test examples - echo "GPU not found" - echo "Testing GPU not found" - echo "GPU not found Testing" - echo "docker login failed" - echo "Testing docker login failed" - echo "docker login failed Testing" + testLog() } } } From 95768d1b22697488f793ab90fbc7ca8e241aa6e7 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 13:02:25 -0500 Subject: [PATCH 33/42] Adding forcing failure to test notifications --- Jenkinsfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Jenkinsfile b/Jenkinsfile index cd7678df1a..5e1a5af3e4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,6 +44,7 @@ def testLog() { echo "Testing docker login failed" echo "docker login failed Testing" """ + error("Forcing failure to test notifications") } // Scan the build logs for failures and send notifications. From 58e1d032441fed82d33240f132168ad94bcba476 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 13:56:47 -0500 Subject: [PATCH 34/42] Removing working cases to test other failure examples --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5e1a5af3e4..1c50698d3c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,8 @@ def testLog() { sh """ echo "GPU not found" echo "Testing GPU not found" - echo "GPU not found Testing" echo "docker login failed" echo "Testing docker login failed" - echo "docker login failed Testing" """ error("Forcing failure to test notifications") } From 6c596b95535fffcacc2d4fadb8199ab5d00d7853 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:21:06 -0500 Subject: [PATCH 35/42] Testing a pattern to support all text variations --- Jenkinsfile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1c50698d3c..5ae56929dd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,8 +39,10 @@ def testLog() { sh """ echo "GPU not found" echo "Testing GPU not found" + echo "GPU not found Testing" echo "docker login failed" echo "Testing docker login failed" + echo "docker login failed Testing" """ error("Forcing failure to test notifications") } @@ -50,10 +52,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /.docker login failed./, description: "Docker login failed"], + [pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /.GPU not found./, description: "GPU not found"], + [pattern: /(.*)GPU not found(.*)/, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] From 1397924c21603123c14d0db3242532eff666eae2 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:25:21 -0500 Subject: [PATCH 36/42] Removed working tests. Validating remaining tests. --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5ae56929dd..d860dc0fca 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -37,10 +37,8 @@ def checkForPattern(pattern, log) { def testLog() { // Todo: Remove test examples sh """ - echo "GPU not found" echo "Testing GPU not found" echo "GPU not found Testing" - echo "docker login failed" echo "Testing docker login failed" echo "docker login failed Testing" """ From 402f21d0a6ccf22c64f252f84768e046690b8810 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:27:18 -0500 Subject: [PATCH 37/42] Removed working tests. Validating remaining tests. --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index d860dc0fca..49949d8851 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -37,9 +37,7 @@ def checkForPattern(pattern, log) { def testLog() { // Todo: Remove test examples sh """ - echo "Testing GPU not found" echo "GPU not found Testing" - echo "Testing docker login failed" echo "docker login failed Testing" """ error("Forcing failure to test notifications") From 8654c0628f83261d3dd64cfb4ec80e9dd2b29fa5 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:29:13 -0500 Subject: [PATCH 38/42] Finished testing failure types. Removed testing code. --- Jenkinsfile | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 49949d8851..1a8be258bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -34,15 +34,6 @@ def checkForPattern(pattern, log) { return [found: false, matchedLine: "", context: ""] } -def testLog() { - // Todo: Remove test examples - sh """ - echo "GPU not found Testing" - echo "docker login failed Testing" - """ - error("Forcing failure to test notifications") -} - // Scan the build logs for failures and send notifications. def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. @@ -1299,7 +1290,6 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" - testLog() } } } From cc75948d1c7f732d102c8e31dc007a2ccd07761f Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 26 Jan 2026 23:50:15 +0100 Subject: [PATCH 39/42] [CK_BUILDER] conv bwd weight testing (#3618) * ck-builder: restructure testing conv In order to prepare for bwd of conv testing, this commit moves some files and types around so that we can reuse ckt::Args for both forward and backwards convolution. * ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp This will allow us to more easily include fwd.hpp from backwards definitions, which is required for initializing bwd values. * ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3 Turns out that the supplied layout isn't actually supported... * ck-builder: ck and reference conv integration for bwd weight * ck-builder: ck bwd weight execution test * ck-builder: ckt::run support for ck-tile bwd weight * ck-builder: ck tile bwd weight execution test * ck-builder: extra debug printing in MatchesReference * ck-builder: make ckt::run return RunResult This type is more convenient than std::tuple, as it will allow us to use google test matchers with this in the future. * ck-builder: RunResult matcher Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error message about how and why running an algorithm failed. * ck-builder: doc fixes * ck-builder: add missing headers --- .../testing/{conv_fwd.hpp => conv/args.hpp} | 64 +--- .../builder/testing/conv/bwd_weight.hpp | 71 +++++ .../builder/testing/conv/bwd_weight_ck.hpp | 276 ++++++++++++++++++ .../ck_tile.hpp} | 92 ++++-- .../ck_tile/builder/testing/conv/fwd.hpp | 69 +++++ .../{conv_fwd_ck.hpp => conv/fwd_ck.hpp} | 58 ++-- .../builder/testing/conv/reference.hpp | 137 +++++++++ .../builder/testing/conv_fwd_reference.hpp | 88 ------ .../builder/testing/tensor_initialization.hpp | 1 + .../ck_tile/builder/testing/testing.hpp | 62 +++- .../builder/testing/testing_reflect.hpp | 2 + experimental/builder/test/CMakeLists.txt | 2 +- ...st_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 59 +++- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 13 +- .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 94 ++++-- .../conv/ck_tile/test_ckb_conv_fwd_e2e.cpp | 13 +- .../builder/test/test_testing_utils.cpp | 17 ++ experimental/builder/test/testing_utils.cpp | 18 ++ experimental/builder/test/testing_utils.hpp | 32 ++ .../builder/test/unit_conv_fwd_testing.cpp | 2 +- experimental/builder/test/unit_validation.cpp | 5 +- .../instances/instance_includes.inc | 3 +- .../instances/instance_run.inc | 8 +- .../grouped_convolution_forward_tile_algs.hpp | 9 +- .../grouped_convolution_signatures.hpp | 2 +- .../src/profile_grouped_conv_fwd_tile.cpp | 2 +- .../test_grouped_convnd_fwd_tile.cpp | 2 +- 27 files changed, 939 insertions(+), 262 deletions(-) rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd.hpp => conv/args.hpp} (82%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd_ck_tile.hpp => conv/ck_tile.hpp} (52%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd_ck.hpp => conv/fwd_ck.hpp} (73%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp delete mode 100644 experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp similarity index 82% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/args.hpp index 51edf41cba..eba6771964 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp @@ -7,26 +7,25 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/testing_reflect.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/tensor_buffer.hpp" -#include "ck_tile/host/convolution_parameter.hpp" -#include "ck_tile/builder/testing/tensor_initialization.hpp" #include "ck_tile/builder/testing/tensor_descriptor.hpp" -#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/host/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item -/// of it is the ConvArgs structure - which contains a complete description +/// of it is the Args structure - which contains a complete description /// of a convolution operation. /// /// It is not intended that this file contains implementation details for /// actually launching a convolution operation. As this can be done /// through different APIs depending on the kernel (CK, CK Tile, or a /// reference implementation), the code dealing with that is split out -/// into a separate header for each implementation. +/// into a separate header for each implementation. Nor does this file +/// deal with details for defining the data types (`Inputs` and `Outputs`) +/// for different conv directions, that is also split out into separate +/// headers to keep this one small. namespace ck_tile::builder::test { @@ -56,7 +55,7 @@ struct ConvTensorLengths /// /// @see Args template - requires ValidConvSignature && ConvDirectionIsForward + requires ValidConvSignature struct Args { constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -204,53 +203,4 @@ struct Args } }; -/// @brief `Inputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Inputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Inputs -{ - void* input; - void* weight; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("input", args.make_input_descriptor(), &Inputs::input); - inspect("weight", args.make_weight_descriptor(), &Inputs::weight); - } -}; - -/// @brief `Outputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Outputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Outputs -{ - void* output; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("output", args.make_output_descriptor(), &Outputs::output); - } -}; - -/// @brief `init_inputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_inputs() -template - requires ValidConvSignature && ConvDirectionIsForward -void init_inputs(const Args& args, Inputs inputs) -{ - init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp new file mode 100644 index 0000000000..ce5811c87a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/error.hpp" + +/// This file deals with the backward weight-specific details of running grouped +/// convolution backwards weight operations. It mainly defines the data +/// structures (`Input` and `Output`), initialization, and validation. Note +/// that for this operation specifically, many of the operations are +/// implemented automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Inputs +{ + void* input; + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("output", args.make_output_descriptor(), &Inputs::output); + } +}; + +/// @brief `Outputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Outputs +{ + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("weight", args.make_weight_descriptor(), &Outputs::weight); + } +}; + +/// @brief `init_inputs()` specialization for backwards convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp new file mode 100644 index 0000000000..0b1ffeb707 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp @@ -0,0 +1,276 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// bwd grouped convolution operations in old CK. The main item is the +/// `run()` function, which is the main implementation used to invoke +/// CK grouped forward convolution kernels. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // TODO: Actually support multiple d + {}, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // TODO: Multiple D lengths/strides + {}, + {}, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +} // namespace detail + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightMultipleDInstance = + detail::CkConvBwdWeightMultipleDInstance; + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// This overload is specialized for Multiple-D. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + {}, // TODO + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + {}, // TODO + {}, // TODO + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp similarity index 52% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index a8f6825524..133d7d69b7 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -3,9 +3,8 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" #include @@ -28,9 +27,39 @@ namespace detail { /// namespace. template concept CkTileConvInstance = requires(Conv&) { + requires ValidConvSignature; { Conv::BlockSize() }; }; +template +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + const ck_tile::stream_config s_conf) +{ + using Conv = std::remove_reference_t; + const auto param = args.to_ck_tile_conv_param(); + + ck_tile::GroupedConvHostArgs + host_args(param, input, weight, {}, output, args.k_batch); + + auto kargs = Conv::MakeKernelArgs(host_args); + + const dim3 grids = Conv::GridSize(kargs); + const dim3 blocks = Conv::BlockSize(); + + if(!Conv::IsSupportedArgument(kargs)) + return RunResult::not_supported("unsupported ck_tile arguments"); + + constexpr index_t minimum_occupancy = + Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; + + return RunResult::from_runtime(ck_tile::launch_kernel( + s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +} + } // namespace detail /// @brief Concept for checking whether a convolution is invoked like CK Tile. @@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance; /// @brief `run()` specialization for forward convolution and CK Tile. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template - requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkTileConvInstance auto& conv, + requires ConvDirectionIsForward +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, const ck_tile::stream_config s_conf = {}) { - using Conv = std::remove_reference_t; - const auto param = args.to_ck_tile_conv_param(); + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(inputs.weight), + static_cast(outputs.output), + s_conf); +} - ck_tile::GroupedConvFwdHostArgs<> host_args( - param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch); - - auto kargs = Conv::MakeKernelArgs(host_args); - - const dim3 grids = Conv::GridSize(kargs); - const dim3 blocks = Conv::BlockSize(); - - if(!Conv::IsSupportedArgument(kargs)) - { - std::cout << "Not supported!"; - return std::make_tuple(false, 0.f); - } - - constexpr index_t minimum_occupancy = - Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; - - return std::make_tuple( - true, - ck_tile::launch_kernel( - s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +/// @brief `run()` specialization for backwards weight convolution and CK Tile. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ConvDirectionIsBackwardWeight +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const ck_tile::stream_config s_conf = {}) +{ + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + s_conf); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp new file mode 100644 index 0000000000..b81892c91e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" + +/// This file deals with the forward-specific details of running grouped +/// convolution forward operations. It mainly defines the data structures +/// (`Input` and `Output`), initialization, and validation. Note that +/// for this operation specifically, many of the operations are implemented +/// automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Inputs +{ + void* input; + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("weight", args.make_weight_descriptor(), &Inputs::weight); + } +}; + +/// @brief `Outputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Outputs +{ + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("output", args.make_output_descriptor(), &Outputs::output); + } +}; + +/// @brief `init_inputs()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp similarity index 73% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp index f911dca21c..5eca79508c 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp @@ -3,14 +3,14 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include #include /// This file contains the implementation details for invoking/testing -/// grouped convolution operations in old CK. The main item is the +/// fwd grouped convolution operations in old CK. The main item is the /// `run()` function, which is the main implementation used to invoke /// CK grouped forward convolution kernels. @@ -18,10 +18,9 @@ namespace ck_tile::builder::test { namespace detail { -/// @brief Concept for checking whether this is the reference convolution -/// implementation. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// -/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except /// with some utility aliases. For that reason, its moved to this detail /// namespace. template > -concept CkConvInstance = requires(Conv& conv, - // TODO: This should be changed depending on IsMultiA etc. - // Currently that is not yet supported elsewhere anyway. - const void* p_a, - const void* p_b, - void* p_e, - std::array lengths, - std::array strides, - std::array filter, - Ops::InElementwiseOp elementwise_a, - Ops::WeiElementwiseOp elementwise_b, - Ops::OutElementwiseOp elementwise_cde) { +concept CkConvFwdInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { + requires ValidConvSignature; + requires ConvDirectionIsForward; + { conv.MakeArgument(p_a, p_b, @@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv, } // namespace detail -/// @brief Concept for checking whether a convolution is invoked like old CK. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is /// likely to be an "old CK" implementation - that is, whether we should @@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv, /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. template -concept CkConvInstance = detail::CkConvInstance; +concept CkConvFwdInstance = detail::CkConvFwdInstance; /// @brief `run()` specialization for forward convolution and old CK. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkConvInstance auto& conv, +[[nodiscard]] RunResult run(CkConvFwdInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, @@ -126,6 +125,9 @@ std::tuple run(CkConvInstance auto& conv, const auto weight_desc = args.make_weight_descriptor(); const auto output_desc = args.make_output_descriptor(); + if(args.k_batch != 1) + return RunResult::not_supported("ck fwd does not support k_batch != 1"); + auto ck_args = conv.MakeArgument(inputs.input, inputs.weight, {}, @@ -147,11 +149,9 @@ std::tuple run(CkConvInstance auto& conv, args.cde_elementwise_op); if(!conv.IsSupportedArgument(ck_args)) - { - std::cout << "invalid argument" << std::endl; - } + return RunResult::not_supported("unsupported ck arguments"); - return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf)); + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp new file mode 100644 index 0000000000..169d0741ff --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp @@ -0,0 +1,137 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors. +template +concept RefConvInstance = requires(Conv& conv, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + ck::utils::conv::ConvParam param) { + requires ValidConvSignature; + { conv.Run(input, weight, output, param) }; +}; + +/// @brief Generic `run` implementation for forward/backwards reference kernels. +/// +/// @tparam SIGNATURE The signature of the operation to perform. +/// +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f for reference). +/// @see run() +template +[[nodiscard]] RunResult +run(RefConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + if(!args.make_input_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv"); + + if(!args.make_weight_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv"); + + if(!args.make_output_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(input, weight, output, param); + return RunResult::from_runtime(0); // ref conv does not return a meaningful runtime. +} + +} // namespace detail + +/// @brief Concept for checking whether this is the reference convolution +/// forward implementation. +template +concept RefConvFwdInstance = + detail::RefConvInstance && + ConvDirectionIsForward; + +/// @brief `run()` specialization for forward convolution and the reference +/// forward implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? + // for now, just concern outselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +[[nodiscard]] RunResult run(RefConvFwdInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, inputs.weight, outputs.output); +} + +/// @brief Concept for checking whether this is the reference convolution +/// backward weight implementation. +template +concept RefConvBwdWeightInstance = + detail::RefConvInstance && + ConvDirectionIsBackwardWeight; + +/// @brief `run()` specialization for forward convolution and the reference +/// backward weight implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(RefConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, outputs.weight, inputs.output); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp deleted file mode 100644 index ff276f7c9c..0000000000 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include -#include - -/// This file contains the implementation details for invoking/testing -/// grouped convolution operations using the reference implementation. -/// The main item is the `run()` function, which is the primary way to -/// invoke the reference execution mechanism. -/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, -/// but its made specific to the reference implementation, which is -/// invoked in a slightly different way. - -namespace ck_tile::builder::test { - -/// @brief Concept for checking whether this is the reference convolution -/// implementation. -/// -/// This concept is used to tell whether a convolution implementation is -/// likely to be the reference implementation - that is, whether we should -/// invoke it like the reference kernel. This is mainly used with `run()` to -/// differentiate which implementation that should be invoked. -/// -/// - SIGNATURE is the operation signature. -/// - Conv is a convolution instance created by the CK Builder API. -template -concept RefConvInstance = requires(Conv& conv, - const void* input, - const void* weight, - void* output, - ck::utils::conv::ConvParam param) { - { conv.Run(input, weight, output, param) }; -}; - -/// @brief `run()` specialization for forward convolution and the reference -/// implementation. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f for reference). -/// @see run() -template - requires ValidConvSignature && - // TODO: Maybe we can unify this implementation for bwd/weight too? - // for now, just concern outselves with reference and see when the - // rest of the bwd/weight plumbing is there. - ConvDirectionIsForward -std::tuple run(RefConvInstance auto& conv, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) -{ - // We don't want to compute the output dims manually, just get - // them via the existing infrastructure - const auto param = args.to_ck_conv_param(); - - // TODO: The reference convolution is currently missing a few features. - // Just throw for now, but regard these as TODO items that should be resolved - // eventually. - - if(!args.make_input_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_weight_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_output_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - - conv.Run(inputs.input, inputs.weight, outputs.output, param); - return std::make_tuple(true, 0.0f); -} - -} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 2976e6c14b..35fc1f4ee8 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -12,6 +12,7 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" #include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck/utility/data_type.hpp" diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index e61d7c4da5..307871b47a 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -3,7 +3,11 @@ #pragma once +#include #include +#include +#include +#include #include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" @@ -288,6 +292,57 @@ ValidationReport validate(const Args& args, Outputs actual, Outputs expected) = delete; +/// @brief This structure represents the result of a run operation. +/// +/// The structure contains multiple fields with information about +/// how the operation completed (or not). See those for more info. +struct RunResult +{ + /// If this value is not set to `std::nullopt`, there was a problem + /// while running the algorithm. In this case, the outputs are not + /// valid (though may be partially or completely overwritten), and + /// the optional contains a short debug message that indicates the + /// problem. + std::optional error = std::nullopt; + + /// The runtime of the kernel in milliseconds, if measured. Whether the + /// runtime is measured at all depends on the stream configuration + /// passed to run(). 0 if not measured or if there was an error. This + /// value is averaged over the total amount of runs actually done. Again, + /// this is usually configured via the stream config. + float runtime = 0.f; + + /// @brief Utility function for constructing a RunResult from an unsupported operation. + /// + /// @param msg A short debug message that will be included in the result. + constexpr static RunResult not_supported(std::string_view msg) + { + return RunResult{.error = std::string(msg)}; + } + + /// @brief Utility function for constructing a RunResult from an average runtime, + /// indicating a successful operation. + /// + /// @param runtime The runtime of the kernel in milliseconds. + constexpr static RunResult from_runtime(const float runtime) + { + return RunResult{.runtime = runtime}; + } + + /// @brief Returns whether this algorithm executed successfully. + /// + /// In this case there should be no message in `error`. + bool is_supported() const { return !this->error.has_value(); } +}; + +inline std::ostream& operator<<(std::ostream& os, const RunResult& result) +{ + if(result.error.has_value()) + return os << "invalid run (" << result.error.value() << ")"; + else + return os << "successful run (" << result.runtime << " ms)"; +} + /// @brief Invoke a device operation created by CK Builder. /// /// This is the main function used to invoke a particular device operation @@ -318,13 +373,14 @@ ValidationReport validate(const Args& args, /// @param outputs The output tensor data. The contents will be overwritten by /// this function. /// @param s_conf Stream config used to launch kernel. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @note This function is explicitly deleted to generate compile errors /// for missing implementations. +/// +/// @see RunResult template -std::tuple run(Operation& operation, +[[nodiscard]] RunResult run(Operation& operation, const Args& args, const Inputs& inputs, const Outputs& outputs, diff --git a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp index 81d5b7a6f5..076b5e9751 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/testing.hpp" + /// testing.hpp requires developers of a type of SIGNATURE to implement /// quite a lot of functionality for each SIGNATURE. For example, next /// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 9890563859..73a682f10c 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - ) +) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) set(BWD_WEIGHT_TESTS diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 4ad97209e5..a3f4a988ef 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -1,23 +1,30 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; + using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, .direction = ckb::ConvDirection::BACKWARD_WEIGHT, .data_type = ckb::DataType::BF16, .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCW}}, + .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = NGKW}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} @@ -30,14 +37,58 @@ constexpr auto ALGORITHM = using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(BwdWeight_1DBf16_CShuffle_V3, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", expected_transfer_parameters, "Filter1x1Stride1Pad0", - "NGCW,GKXC,NGKW", + "GNWC,GKXC,GNWK", "PassThrough,PassThrough,PassThrough", "Intrawave", "v2"}); } + +TEST(BwdWeight_1DBf16_CShuffle_V3, Execution) +{ + if(!ck_tile::get_device_name().starts_with("gfx9")) + { + // Note: XDL kernel + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 64}, + .filter = {.width = 1}, + }, + .filter_strides = {.width = 1}, + .filter_dilation = {.width = 1}, + .input_left_pad = {.width = 0}, + .input_right_pad = {.width = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 3e5e39191e..51bc45c29b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -4,8 +4,9 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, @@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create) "MNKPadding"}); } -TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) +TEST(Fwd2DFp16_CShufV3_GNHWC, Execution) { if(!ck_tile::get_device_name().starts_with("gfx9")) { + // Note: XDL kernel GTEST_SKIP() << "unsupported architecture"; } @@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 292d852b91..60dc45545f 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -1,35 +1,47 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "testing_utils.hpp" -namespace { +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; -using namespace ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; -TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(ckb::TileConvSpecialization::DEFAULT) + .with_tile_thread_block(cku::TileThreadBlock_64x64x64) + .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(cku::TileTransfer_4x4x4) + .with_tile_optimizations(ckt::TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +using Reference = ckb::ConvBuilder::Instance; + +TEST(BwdWeight_2D_FP16_NHWGC, Create) { - constexpr ConvSignature BwdWeightConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto BwdWeightConvAlgorithm = - ConvAlgorithm_Tile_GroupedConvolutionKernel{} - .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(TileThreadBlock_64x64x64) - .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(TileTransfer_4x4x4) - .with_tile_optimizations(TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - - using Builder = ConvBuilder; - run_ck_tile_test({ + cku::run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", "NHWGC_GKYXC_NHWGK", @@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_ }); } -} // namespace +TEST(BwdWeight_2D_FP16_NHWGC, Execution) +{ + ckt::Args args = { + .lengths = + { + .batch_size = 2, + .groups = 4, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 32, .height = 56}, + .filter = {.width = 3, .height = 3}, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp index 128744dcc6..650c217b71 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp @@ -4,8 +4,8 @@ #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); - EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get())); + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/test_testing_utils.cpp b/experimental/builder/test/test_testing_utils.cpp index 43bbbd69eb..100122eef3 100644 --- a/experimental/builder/test/test_testing_utils.cpp +++ b/experimental/builder/test/test_testing_utils.cpp @@ -5,11 +5,14 @@ #include "testing_utils.hpp" +namespace ckt = ck_tile::builder::test; + using ck_tile::test::HipError; using ck_tile::test::HipSuccess; using ck_tile::test::InstanceMatcher; using ck_tile::test::InstanceSet; using ck_tile::test::StringEqWithDiff; +using ck_tile::test::SuccessfulRun; TEST(InstanceSet, FromFactory) { @@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic) EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue))); EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue))); } + +TEST(RunResultMatcher, Basic) +{ + EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun()); + EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun())); +} + +TEST(RunResultMatcher, ExplainMatchResult) +{ + testing::StringMatchResultListener listener; + EXPECT_TRUE(!ExplainMatchResult( + SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener)); + EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error")); +} diff --git a/experimental/builder/test/testing_utils.cpp b/experimental/builder/test/testing_utils.cpp index b60c35333e..e9677e5940 100644 --- a/experimental/builder/test/testing_utils.cpp +++ b/experimental/builder/test/testing_utils.cpp @@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const return ::testing::MakeMatcher(new HipStatusMatcher(error)); } +bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const +{ + if(actual.error.has_value() && listener) + *listener << "run failed: " << actual.error.value(); + + return actual.is_supported(); +} + +void RunResultMatcher::DescribeTo(std::ostream* os) const { *os << "successful run"; } + +void RunResultMatcher::DescribeNegationTo(std::ostream* os) const { *os << "unsuccessful run"; } + +::testing::Matcher SuccessfulRun() +{ + return ::testing::MakeMatcher(new RunResultMatcher()); +} + } // namespace ck_tile::test diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index b84d53b6df..55de133a2a 100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); +/// @brief RunResult matcher +/// +/// `ckt::run` returns a RunResult which indicates whether there was any +/// problem while running the algorithm. This matcher is used to match those +/// values. +struct RunResultMatcher : public ::testing::MatcherInterface +{ + bool MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override; +}; + +/// @brief Construct a Google Test matcher that checks that a ckt::run result +/// was successful. +::testing::Matcher SuccessfulRun(); + template struct ReferenceOutputMatcher : public ::testing::MatcherInterface> @@ -180,6 +197,21 @@ struct ReferenceOutputMatcher if(listener->IsInterested() && !errors.empty()) { *listener << errors.size() << " tensors failed to validate"; + + for(const auto& e : errors) + { + *listener << "\n - " << e.tensor_name << ": "; + + if(e.is_all_zero()) + *listener << "all elements in actual and expected tensors are zero"; + else + { + // Round to 2 digits + const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f; + *listener << e.wrong_elements << "/" << e.total_elements + << " incorrect elements (~" << percentage << "%)"; + } + } } return errors.empty(); diff --git a/experimental/builder/test/unit_conv_fwd_testing.cpp b/experimental/builder/test/unit_conv_fwd_testing.cpp index be95a29a2d..9fc07568b4 100644 --- a/experimental/builder/test/unit_conv_fwd_testing.cpp +++ b/experimental/builder/test/unit_conv_fwd_testing.cpp @@ -3,7 +3,7 @@ #include "impl/conv_signature_types.hpp" #include "testing_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" #include "ck_tile/builder/testing/tensor_foreach.hpp" #include #include diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp index a83d034ac2..0dad8593fb 100644 --- a/experimental/builder/test/unit_validation.cpp +++ b/experimental/builder/test/unit_validation.cpp @@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect) testing::StringMatchResultListener listener; EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); - EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); + EXPECT_THAT(listener.str(), + StringEqWithDiff( // + "1 tensors failed to validate\n" + " - a: 625/625 incorrect elements (~100%)")); } diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc index 4b4c144428..ae451caec0 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc @@ -1,5 +1,6 @@ #include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc index 6b8024fa93..016ef3e653 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc @@ -2,8 +2,6 @@ using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; -auto conv = Instance{}; -bool is_supported; -float avg_time; -std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf); -return std::make_tuple(is_supported, avg_time, conv.GetInstanceString()); +auto conv = Instance{}; +ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf); +return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString()); diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp index e58c884729..9f7227a699 100644 --- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -9,8 +9,9 @@ #include "grouped_convolution_signatures.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/builder/conv_builder.hpp" namespace ck_tile::builder::profiling { @@ -113,8 +114,8 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args, auto reference = ckt::alloc_outputs(args); using ReferenceInstance = typename ckb::ConvBuilder::Instance; - auto ref_conv = ReferenceInstance{}; - ckt::run(ref_conv, args, inputs, reference.get()); + auto ref_conv = ReferenceInstance{}; + [[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get()); [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); diff --git a/profiler/include/profiler/grouped_convolution_signatures.hpp b/profiler/include/profiler/grouped_convolution_signatures.hpp index 5103b0f235..0f87e283bb 100644 --- a/profiler/include/profiler/grouped_convolution_signatures.hpp +++ b/profiler/include/profiler/grouped_convolution_signatures.hpp @@ -6,7 +6,7 @@ #include #include "../../experimental/builder/test/impl/conv_signature_types.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ck_tile::builder::profiling { diff --git a/profiler/src/profile_grouped_conv_fwd_tile.cpp b/profiler/src/profile_grouped_conv_fwd_tile.cpp index 8023dcf2f6..1a1e8b769a 100644 --- a/profiler/src/profile_grouped_conv_fwd_tile.cpp +++ b/profiler/src/profile_grouped_conv_fwd_tile.cpp @@ -6,7 +6,7 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp index c04a15ec98..068811cf00 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -7,7 +7,7 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" From c190d8d61f2ea44a0d04b8c6706434098ca0c691 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Tue, 27 Jan 2026 09:49:42 +0100 Subject: [PATCH 40/42] [CK tests] Extend conv GPU reference (#3539) * test_convnd_fwd * test_convnd_bwd_data * test_conv_bwd_data_scale * test_grouped_convnd_fwd_clamp * test_grouped_convnd_fwd_scale * multiple A/B tensors and D tensor for fwd GPU ref * test_grouped_convnd_fwd_scaleadd_ab * test_grouped_convnd_fwd_bias_clamp * test_grouped_convnd_fwd_bilinear * test_grouped_convnd_fwd_gk_bias_clamp * Extend GPU reference to enable batchnorm epilogue * test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp * test_grouped_conv_bwd_data_bilinear * test_grouped_convnd_bwd_weight_bilinear * Add missing template instantiation * Perform operations in float in reference * Slightly increase tolerance for batchnorm profiler * Revert "Slightly increase tolerance for batchnorm profiler" This reverts commit a3b247522902c712930369f466c376a6430f4f67. * Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp" This reverts commit 6da4576060215e1d3e0e79ca355c340d3546363c. * Revert "Extend GPU reference to enable batchnorm epilogue" This reverts commit e2f75fa10e80740eddb7a46f0a51aaac74b8f1a5. * Clarify variable names * Refactor elementwise ops into helper functions * Make helpers C++17-compatible --- .../element/unary_element_wise_operation.hpp | 23 + .../gpu/naive_conv_bwd_data_gpu.hpp | 465 ++++++++++++----- .../gpu/naive_conv_bwd_weight_gpu.hpp | 475 ++++++++++++++---- .../gpu/naive_conv_fwd_gpu.hpp | 468 +++++++++++++---- .../gpu/naive_conv_utils.hpp | 117 ++++- .../profiler/profile_conv_bwd_data_impl.hpp | 56 ++- .../profiler/profile_conv_fwd_impl.hpp | 45 +- ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 73 ++- ...profile_grouped_conv_fwd_bilinear_impl.hpp | 59 ++- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 77 ++- test/convnd_bwd_data/convnd_bwd_data_xdl.cpp | 2 +- test/convnd_fwd/convnd_fwd_xdl.cpp | 2 +- test/gpu_reference/CMakeLists.txt | 3 + test/gpu_reference/gpu_reference_utils.hpp | 225 +++++++++ .../test_gpu_reference_conv_fwd_multi_abd.cpp | 319 ++++++++++++ .../test_grouped_conv_bwd_data_bilinear.cpp | 81 +-- .../test_grouped_conv_bwd_data_scale.cpp | 51 +- ...est_grouped_convnd_bwd_weight_bilinear.cpp | 83 +-- .../test_grouped_convnd_fwd_bilinear.cpp | 4 +- .../test_grouped_convnd_fwd_scaleadd_ab.cpp | 52 +- .../test_grouped_convnd_fwd_bias_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_gk_bias_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_scale.cpp | 4 +- 24 files changed, 2217 insertions(+), 473 deletions(-) create mode 100644 test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 6cd7b3d9f6..31047c03b2 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1631,6 +1631,13 @@ struct ConvInvscale e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float / scale_in_ / scale_wei_ / scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1656,6 +1663,13 @@ struct ConvScale e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float * scale_in_ * scale_wei_ * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1683,6 +1697,15 @@ struct ConvScaleRelu e = type_convert(x * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + float x; + Relu{}.template operator()(x, c_float * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp index aecf519c10..5210265cef 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp @@ -10,49 +10,55 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward data convolution kernel working with packed (contiguous) tensors -// Computes gradients w.r.t. input from output gradients and weights -// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], -// output[G][N][K][spatial] +// Optimized backward data convolution kernel working with packed (contiguous) tensors with +// multi-ABD support Computes gradients w.r.t. input from output gradients and weights Assumes +// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial] template -__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - const OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, + const WeiDataType* const* __restrict__ p_weis, + const OutDataType* const* __restrict__ p_outs, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t x = 0; x < X; ++x) { @@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t wo = w_tmp / stride_x; if(wo >= 0 && wo < Wo) { - const OutDataType* out_gnk = out_gn; - const WeiDataType* wei_gkc = wei_g + c * wei_stride_c; + // Pointers at current filter position + const OutDataType* output_grad_g_n_k = output_grad_g_n; + const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c; for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnk[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkc[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_g_n_k, + p_outs + 1, + g * out_stride_g + n * out_stride_n, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_g_k_c, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } } } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val; } } @@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t y = 0; y < Y; ++y) { @@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkh = out_gn + ho * out_stride_h; - const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h; + const WeiDataType* weight_at_c_y = + weight_g + c * wei_stride_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnkh[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra + // A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_outs + 1, + g * out_stride_g + n * out_stride_n + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op(in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + hi * p_d_strides[0][3] + + wi * p_d_strides[0][4]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] = in_val; } @@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t z = 0; z < Z; ++z) { @@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t do_idx = d_tmp / stride_z; if(do_idx >= 0 && do_idx < Do) { - const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d; - const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z; + // Pointers at current spatial depth + const OutDataType* output_grad_at_d = + output_grad_g_n + do_idx * out_stride_d; + const WeiDataType* weight_at_c_z = + weight_g + c * wei_stride_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { @@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; + const WeiDataType* weight_at_c_z_y = + weight_at_c_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, - out_gnkdh[k * out_stride_k + wo]); - wei_op(wei_val, - wei_gkczy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation + // with extra A tensors + detail::apply_multi_tensor_elementwise_op< + NumAExtra>(out_val, + out_op, + output_grad_at_d_h, + p_outs + 1, + g * out_stride_g + + n * out_stride_n + + do_idx * out_stride_d + + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with + // extra B tensors + detail::apply_multi_tensor_elementwise_op< + NumBExtra>( + wei_val, + wei_op, + weight_at_c_z_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d + hi * in_stride_h + wi] = in_val; } } } -// GPU reference backward data convolution - takes ConvParam directly -template -void naive_conv_bwd_data(TIn* p_in, - const TWei* p_wei, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility +void naive_conv_bwd_data_multi_abd( + TIn* p_in, + const std::array& p_weis, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in, // Allocate packed buffers SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); - TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + + std::vector out_packed_bufs; + out_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + out_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_outs_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_outs_packed[i] = static_cast(out_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in, // Pack output and weight tensors to contiguous layout (inputs to bwd data) constexpr int block_size = 256; - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs, + p_outs_packed.data(), + (NumAElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in, if(ndim == 1) { - naive_conv_bwd_data_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_data_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in, } else // 3D { - naive_conv_bwd_data_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_data - now a zero-overhead wrapper +template +inline void naive_conv_bwd_data(TIn* p_in, + const TWei* p_wei, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_weis = {p_wei}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in, + p_weis, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp index f46b072baa..8cee2e2b77 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp @@ -10,49 +10,58 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward weight convolution kernel working with packed (contiguous) tensors +// Optimized backward weight convolution kernel working with packed (contiguous) tensors with +// multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial], // weight_grad[G][K][C][filter] // Computes gradient with respect to weights template -__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in, - WeiDataType* __restrict__ p_wei_grad, - const OutDataType* __restrict__ p_out_grad, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void +naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins, + WeiDataType* __restrict__ p_wei_grad, + const OutDataType* const* __restrict__ p_out_grads, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gn[wi]); - out_op(out_val, out_gn_k[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_n_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle output gradient element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_n_k, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val; } } @@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h; + // Pointers at current spatial height + const InDataType* input_at_h = input_at_n_c + hi * in_stride_h; + const OutDataType* output_grad_at_h = + output_grad_at_n_k + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - out_op(out_val, out_gn_kh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra B + // tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } @@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op(wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + y * p_d_strides[0][3] + + x * p_d_strides[0][4]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y + x] = wei_val; } @@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t do_idx = 0; do_idx < Do; ++do_idx) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_n_c + di * in_stride_d; + const OutDataType* output_grad_at_d = + output_grad_at_n_k + do_idx * out_stride_d; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - out_op(out_val, out_gn_kdh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra + // B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_d_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + do_idx * out_stride_d + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } @@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z + y * wei_stride_y + x] = wei_val; } } } -// GPU reference backward weight convolution - takes ConvParam directly -template -void naive_conv_bwd_weight(const TIn* p_in, - TWei* p_wei_grad, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility +void naive_conv_bwd_weight_multi_abd( + const std::array& p_ins, + TWei* p_wei_grad, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in, out_total *= l; // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut)); + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); + + std::vector out_grad_packed_bufs; + out_grad_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); TWei* p_wei_grad_packed = static_cast(wei_grad_packed_buf.GetDeviceBuffer()); - TOut* p_out_grad_packed = static_cast(out_grad_packed_buf.GetDeviceBuffer()); + + std::array p_out_grads_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_out_grads_packed[i] = static_cast(out_grad_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in, // Pack input and output_grad tensors to contiguous layout (inputs to bwd weight) constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], + p_out_grads_packed[i], + d_out_lengths, + d_out_strides, + dim_count, + out_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs, + p_out_grads_packed.data(), + (NumBElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in, if(ndim == 1) { - naive_conv_bwd_weight_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_weight_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else // 3D { - naive_conv_bwd_weight_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_weight - now a zero-overhead wrapper +template +inline void +naive_conv_bwd_weight(const TIn* p_in, + TWei* p_wei_grad, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_wei_grad, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp index 131b632a25..7bf9b49998 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp @@ -10,48 +10,56 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized convolution kernel working with packed (contiguous) tensors +// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], // output[G][N][K][spatial] template -__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_fwd_packed_multi_abd( + const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra) + const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra) + const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers + const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays + OutDataType* __restrict__ p_out, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gc = in_g + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gc[wi]); - wei_op(wei_val, wei_gkc[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val; } } @@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const InDataType* input_at_h = input_at_c + hi * in_stride_h; + const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - wei_op(wei_val, wei_gkcy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } @@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op(out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + ho * p_d_strides[0][3] + + wo * p_d_strides[0][4]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] = out_val; } @@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t z = 0; z < Z; ++z) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_c + di * in_stride_d; + const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - wei_op(wei_val, wei_gkczy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_z_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } @@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + ho * out_stride_h + wo] = out_val; } } } -// GPU reference convolution - takes ConvParam directly -template -void naive_conv_fwd(const TIn* p_in, - const TWei* p_wei, - TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility +void naive_conv_fwd_multi_abd( + const std::array& p_ins, + const std::array& p_weis, + const std::array& p_ds, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in, for(auto l : out_lengths) out_total *= l; - // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); + // Allocate packed buffers for all A and B tensors + // Use separate allocations to avoid copy assignment issues with RAII wrapper + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); + // Get packed buffer pointers + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); // Compute strides and allocate device arrays for pack/unpack @@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in, // Pack input and weight tensors to contiguous layout constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + // Pack all A tensors + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + // Pack all B tensors + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + // NOTE: D tensors are NOT packed - they are used directly with their original strides + // to support broadcasting (e.g., BiasGK layout with zero strides) + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + // Allocate and copy strides to device + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + // D tensors use original pointers (not packed) to support broadcasting + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in, if(ndim == 1) { - naive_conv_fwd_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in, } else if(ndim == 2) { - naive_conv_fwd_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in, } else // 3D { - naive_conv_fwd_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_fwd - now a zero-overhead wrapper +template +inline void naive_conv_fwd(const TIn* p_in, + const TWei* p_wei, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_weis = {p_wei}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_weis, + p_ds, + p_out, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp index 0a7b58b310..50b65357a2 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp @@ -22,9 +22,39 @@ struct SimpleDeviceMem HIP_CHECK_ERROR(hipMalloc(static_cast(&p_mem_), mem_size)); } + // Delete copy operations (resource should not be copied) + SimpleDeviceMem(const SimpleDeviceMem&) = delete; + SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete; + + // Define move operations + SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_) + { + other.p_mem_ = nullptr; + } + + SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept + { + if(this != &other) + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + p_mem_ = other.p_mem_; + other.p_mem_ = nullptr; + } + return *this; + } + void* GetDeviceBuffer() { return p_mem_; } - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + ~SimpleDeviceMem() + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + } void* p_mem_; }; @@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src, } } +namespace detail { + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result, + Op&& element_op, + const DataType* const* tensor_ptrs, + long_index_t element_offset, + std::index_sequence) +{ + element_op(result, tensor_ptrs[Is][element_offset]...); +} + +// Generic helper for A and B tensors (works in all directions) +template +__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result, + Op&& element_op, + const DataType* primary_ptr, + const DataType* const* extra_ptrs, + long_index_t extra_base_offset, + long_index_t element_offset) +{ + const DataType* tensor_ptrs[NumExtraTensors + 1]; + tensor_ptrs[0] = primary_ptr; + + static_for<1, NumExtraTensors + 1, 1>{}( + [&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; }); + + apply_multi_tensor_impl(result, + element_op, + tensor_ptrs, + element_offset, + std::make_index_sequence{}); +} + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out, + Op&& element_op, + float computed_value, + const float* d_values, + std::index_sequence) +{ + float temp_out; + element_op(temp_out, computed_value, d_values[Is]...); + result_out = type_convert(temp_out); +} + +// Specialized helper for D tensors with stride calculations and float conversion +template +__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out, + Op&& element_op, + float computed_value, + const DDataType* const* p_ds, + const index_t* const* p_d_strides, + index_t g, + index_t n, + index_t c_or_k, + long_index_t spatial_linear_index) +{ + if constexpr(NumDTensors == 0) + { + element_op(result_out, computed_value); + } + else + { + float d_values[NumDTensors]; + + // Compute all D tensor indices and convert to float + static_for<0, NumDTensors, 1>{}([&](auto i) { + const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] + + c_or_k * p_d_strides[i][2] + spatial_linear_index; + d_values[i] = type_convert(p_ds[i][d_idx]); + }); + + apply_d_tensor_impl(result_out, + element_op, + computed_value, + d_values, + std::make_index_sequence{}); + } +} + +} // namespace detail + } // namespace ref } // namespace ck diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index a0f9b9ac25..bf5ffcb5d2 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -17,6 +17,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" namespace ck { namespace profiler { @@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(output.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); - if(do_verification) + // profile device Conv instances + bool pass = true; + + if(do_verification == 1) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData gpu_ref_input(in_g_n_c_wis_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * + input_device_result.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data()); + } + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData gpu_ref_output(out_g_n_k_wos_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data()); + } using DeviceOp = ck::tensor_operation::device::DeviceConvFwd(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 50cd58eec3..2a282edbc8 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, bias_device_buf.ToDevice(bias.mData.data()); // run reference op - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + if constexpr(BiasGK) + { + // For GK bias layout: G*K, zero strides for N and spatial dimensions + d_strides_vec[0] = K; + d_strides_vec[1] = 0; + d_strides_vec[2] = 1; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_strides_vec[3 + i] = 0; + } + } + else + { + // Full GNKHW layout - same as output + ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin()); + } + + std::array d_ptrs = { + reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp index 3f4905c110..b439428cda 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -22,6 +22,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl( wei_device_buf.ToDevice(weight.mData.data()); d_device_buf.ToDevice(d_tensor.mData.data()); - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< NDimSpatial, InDataType, @@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl( host_output(idx) = ck::type_convert(out_val); }); } + else if(do_verification == 2) + { + // GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + // D tensor has same layout as output + ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin()); + + std::array d_ptrs = { + reinterpret_cast(d_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DDataType>( // Explicitly specify D tensor type + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index acdc937a33..9444996c25 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -7,6 +7,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" @@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, std::cout << "scale_out: " << scale_out << std::endl; // run reference op - if(do_verification) + if(do_verification == 1) { std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; @@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, } }); } + else if(do_verification == 2) + { + // GPU reference + // WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference + // Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md) + if constexpr(std::is_same_v && + std::is_same_v) + { + // Compute conv to CShuffleDataType (float), then post-process on CPU + DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_c_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + PassThrough{}); + + ck::hip_check_error(hipDeviceSynchronize()); + + Tensor gpu_c(out_g_n_k_wos_desc); + gpu_ref_c_dev.FromDevice(gpu_c.mData.data()); + + // Post-process on CPU to match CPU reference behavior + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(gpu_c(idx)); + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + }); + } + else + { + // Normal path for non-int8 or non-Scale cases + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * + device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(host_output.mData.data()); + } + } std::string best_op_name; float best_avg_time = 0; @@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, best_gb_per_sec = gb_per_sec; } - if(do_verification) + if(do_verification == 1) { out_device_buf.FromDevice(device_output.mData.data()); @@ -259,6 +311,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, << std::endl; } } + else if(do_verification == 2) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = + pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and GPU ref results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "gpu_ref_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp index 98f466a2b3..3e4eb07a64 100644 --- a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp @@ -46,7 +46,7 @@ class TestConvndBwdData : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/convnd_fwd/convnd_fwd_xdl.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp index a2fdcaf870..0377b01bb2 100644 --- a/test/convnd_fwd/convnd_fwd_xdl.cpp +++ b/test/convnd_fwd/convnd_fwd_xdl.cpp @@ -47,7 +47,7 @@ class TestConvndFwd : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/gpu_reference/CMakeLists.txt b/test/gpu_reference/CMakeLists.txt index 443818feb3..d1c3908849 100644 --- a/test/gpu_reference/CMakeLists.txt +++ b/test/gpu_reference/CMakeLists.txt @@ -4,6 +4,9 @@ add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp) target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility) +add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp) +target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility) + add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp) target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility) diff --git a/test/gpu_reference/gpu_reference_utils.hpp b/test/gpu_reference/gpu_reference_utils.hpp index fc017c8734..88306d51a4 100644 --- a/test/gpu_reference/gpu_reference_utils.hpp +++ b/test/gpu_reference/gpu_reference_utils.hpp @@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType } } +// Forward convolution with D tensor support +template +bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& d_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& d_dev, + DeviceMem& output_dev, + OutElementOp out_element_op) +{ + using InElementOp = tensor_operation::element_wise::PassThrough; + using WeiElementOp = tensor_operation::element_wise::PassThrough; + + // Create D tensor lengths and strides for GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + d_lengths_vec[0] = params.G_; + d_lengths_vec[1] = params.N_; + d_lengths_vec[2] = params.K_; + for(index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(params.output_spatial_lengths_[i]); + } + + std::vector d_strides_vec = + ref::compute_conv_tensor_strides(d_lengths_vec, params.num_dim_spatial_); + + std::array d_ptrs = { + reinterpret_cast(d_dev.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + // Call GPU reference with D tensor + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer())}; + + ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> d_tensors_ref = {d_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + InElementOp{}, + WeiElementOp{}, + out_element_op, + {}, // A tensors + {}, // B tensors + d_tensors_ref); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + +// Forward convolution with multiple A/B tensor support +template +bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& a_extra_cpu, + const Tensor& b_extra_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& a_extra_dev, + DeviceMem& b_extra_dev, + DeviceMem& output_dev, + InElementOp in_element_op, + WeiElementOp wei_element_op) +{ + using OutElementOp = tensor_operation::element_wise::PassThrough; + + // Call GPU reference with extra A and B tensors + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer()), + reinterpret_cast(a_extra_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer()), + reinterpret_cast(b_extra_dev.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + OutElementOp{}); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> a_tensors_ref = {a_extra_cpu}; + std::array, 1> b_tensors_ref = {b_extra_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + in_element_op, + wei_element_op, + OutElementOp{}, + a_tensors_ref, + b_tensors_ref, + {}); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + } // namespace test } // namespace ck diff --git a/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp new file mode 100644 index 0000000000..ebe1e9695c --- /dev/null +++ b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp @@ -0,0 +1,319 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "gpu_reference_utils.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using namespace ck; +using ck::test::ConvKernelType; + +// ==================== D Tensor (Bias) Tests ==================== + +template +bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::AddClamp; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(bias, bias_dev); + + // Test with AddClamp (bias operation with clamping) + AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias) +{ + auto params = test::conv_test_shapes::get_3d_small(); + bool result = test_conv_gpu_ref_with_bias<3, + float, + float, + float, + tensor_layout::convolution::GNCDHW, + tensor_layout::convolution::GKCZYX, + tensor_layout::convolution::GNKDHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g4(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== D Tensor (Bilinear) Tests ==================== + +template +bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::Bilinear; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor d_tensor(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(d_tensor, d_dev); + + // Test with Bilinear: y = alpha * conv_result + beta * d_tensor + Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bilinear<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== Multiple A/B (ScaleAdd) Tests ==================== + +template +bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::ScaleAdd; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input) + Tensor b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight) + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType)); + DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(a_extra, a_extra_dev); + test::initialize_and_copy_tensor(b_extra, b_extra_dev); + + // Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1 + ScaleAdd in_element_op(2.0f); // scale factor for input + ScaleAdd wei_element_op(1.5f); // scale factor for weight + + return test::test_conv_fwd_with_multi_ab_impl(params, + input, + weight, + a_extra, + b_extra, + input_dev, + weight_dev, + a_extra_dev, + b_extra_dev, + output_dev, + in_element_op, + wei_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp index b45f204b40..ea7289d6bf 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -63,37 +63,62 @@ class TestGroupedConvndBwdData : public ::testing::Test Tensor& out, Tensor& d) { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData(); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(in_g_n_c_wis_desc, d_lengths, d_strides); - auto ref_invoker = ref_conv.MakeInvoker(); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - Bilinear{alpha, beta}, - WeiElementOp{}, - OutElementOp{}, - {}, - {}, - d_tensors); + DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); - ref_invoker.Run(ref_argument); + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_data_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InDataType>( + static_cast(in_device_buf.GetDeviceBuffer()), + {static_cast(wei_device_buf.GetDeviceBuffer())}, + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{alpha, beta}, + WeiElementOp{}, + OutElementOp{}); + + in_device_buf.FromDevice(in_host.mData.data()); } bool PerformConvDataBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp index 84d013bca7..f1f985883c 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test void RunReference(ck::utils::conv::ConvParam& conv_param, Tensor& in_host, - Tensor& wei, - Tensor& out) + DeviceMem& wei_device_buf, + DeviceMem& out_device_buf) { - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise - Tensors*/ - {}; + // GPU reference + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization - auto ref_invoker = ref_conv.MakeInvoker(); + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{alpha}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(in_host.mData.data()); } bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) @@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); - in_device_buf.ToDevice(in_device.mData.data()); out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + RunReference(conv_param, in_host, wei_device_buf, out_device_buf); + std::array out_lengths{}; std::array out_strides{}; std::array wei_lengths{}; @@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - RunReference(conv_param, in_host, wei, out); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD& out, Tensor& d) { - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdWeight{}; + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in, - wei_host, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{}, - WeiElementOp{alpha, beta}, - OutElementOp{}, - {}, - {}, - d_tensors); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(wei_g_k_c_xs_desc, d_lengths, d_strides); - ref_invoker.Run(ref_argument); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; + + DeviceMem d_device_buf(sizeof(WeiDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); + + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_host.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_weight_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + WeiDataType>( + {static_cast(in_device_buf.GetDeviceBuffer())}, + static_cast(wei_device_buf.GetDeviceBuffer()), + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{}, + WeiElementOp{alpha, beta}, + OutElementOp{}); + + wei_device_buf.FromDevice(wei_host.mData.data()); } bool PerformConvWeightBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp index 1b37f5eb4e..645aab0151 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -66,10 +66,10 @@ class TestGroupedConvndFwdBilinear : public ::testing::Test OutDataType, AComputeType, BComputeType, - IndexType>(true, // do_verification + IndexType>(2, // do_verification 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param, bilinear_op); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index 199a50f0fd..e78e61f707 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -24,6 +24,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" using I8 = int8_t; using F16 = ck::half_t; @@ -131,39 +132,34 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, wei_device_buf.ToDevice(weight.mData.data()); wei_bias_device_buf.ToDevice(weight_bias.mData.data()); - // Run reference op + // Run GPU reference if(do_verification) { - const std::array, NumAs - 1> elementwise_a_tensors = {input_bias}; - const std::array, NumBs - 1> elementwise_b_tensors = {weight_bias}; - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer()), + reinterpret_cast(in_bias_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer()), + reinterpret_cast(wei_bias_device_buf.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weight, - host_output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op, - elementwise_a_tensors, - elementwise_b_tensors); + ck::ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); - // init host output to zero - host_output.SetZero(); + HIP_CHECK_ERROR(hipDeviceSynchronize()); - ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(host_output.mData.data()); } std::string best_op_name; diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index d1706d4cec..68a8b016e3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, false /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp index fef485a950..2c04b52b4f 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -50,7 +50,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, Clamp>( - true, // do_verification + 2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp index a78a17cbf4..78cfe126a3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, true /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp index b4179cae62..b2a9cff231 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -58,10 +58,10 @@ class TestGroupedConvndFwdScale : public ::testing::Test OutDataType, ck::tensor_operation::element_wise::Scale, InDataType, - InDataType>(true, // do_verification + InDataType>(2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param); } EXPECT_TRUE(pass); From 3d67e6c4927a9daea9076fab75b23fb44fdc22b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 27 Jan 2026 10:04:11 +0100 Subject: [PATCH 41/42] [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624) * [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err * Update test_grouped_convnd_fwd_tile.cpp * Update test_grouped_convnd_fwd_tile.cpp * Update conv_tuning_params.hpp * clang format fix * Update CMakeLists.txt --- .../factory/helpers/ck/conv_tuning_params.hpp | 3 + .../ck_tile/conv_tile_tuning_params.hpp | 8 +++ .../ck_tile/builder/testing/validation.hpp | 12 +++- .../builder/include/ck_tile/builder/types.hpp | 2 + .../configs/tests/ndhwgc_bf16.conf | 6 +- .../configs/tests/ndhwgc_fp16.conf | 6 +- .../configs/tests/ndhwgc_fp32.conf | 6 +- .../configs/tests/nhwgc_bf16.conf | 6 +- .../configs/tests/nhwgc_fp16.conf | 6 +- .../configs/tests/nhwgc_fp32.conf | 6 +- include/ck_tile/host/check_err.hpp | 2 +- .../grouped_convolution_forward_tile_algs.hpp | 55 +++++++++++++++++-- test/grouped_convnd_fwd/CMakeLists.txt | 13 ++--- .../test_grouped_convnd_fwd_tile.cpp | 29 +++++----- 14 files changed, 114 insertions(+), 46 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c..3b1ea65695 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm() case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; default: throw "Unknown PipelineVersion"; @@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE."; case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; default: throw "Unknown GridwiseGemmPipelineVersion"; } @@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_pipeline::v3; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: return ck_pipeline::v5; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp index b7df0e4d0e..12482f3206 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -91,6 +91,13 @@ struct TilePipelineType using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; }; +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; +}; + template consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() { @@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3; case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp index 158f271e21..8410a71b15 100644 --- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -51,6 +51,9 @@ struct ValidationReport /// The number of elements which were bitwise 0. uint64_t zero_elements; + // Max error. + double max_error; + /// @brief Check whether both the output and reference tensor were both all zeros. /// /// If both tensors are all zero, it indicates either an incorrect testing setup @@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name, // Initial pass: count errors // Allocate and reset counter - auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); - check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); + auto d_counters = alloc_buffer(sizeof(uint64_t) * 3); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3)); auto d_error_count = &reinterpret_cast(d_counters.get())[0]; auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + auto d_max_error = &reinterpret_cast(d_counters.get())[2]; tensor_foreach(descriptor.get_lengths(), [=](auto index) { using CKType = typename factory::internal::DataTypeToCK
::type; @@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name, const auto r = static_cast(type_convert(b)); const auto err = std::abs(o - r); + atomicMax(d_max_error, err); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { // We expect the number of errors to be very low, so just use an atomic @@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name, check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); uint64_t zero_count = 0; check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + double max_error = 0; + check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost)); // TODO: Gather detailed coordinates. @@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name, .wrong_elements = error_count, .total_elements = descriptor.get_element_size(), .zero_elements = zero_count, + .max_error = max_error, }); return reports_.back().is_ok(); diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c4cca05e52..dad123bae5 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -157,6 +157,7 @@ enum class PipelineVersion V3, V4, V5, + V6, WEIGHT_ONLY }; @@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver) case V3: return "V3"; case V4: return "V4"; case V5: return "V5"; + case V6: return "V6"; case WEIGHT_ONLY: return "WEIGHT_ONLY"; default: return "Unknown"; } diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a1be8027b2..2ba3b1e7c3 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -19,7 +19,7 @@ namespace ck_tile { /** @brief Maximum number of error values to display when checking errors */ -constexpr int ERROR_DETAIL_LIMIT = 128; +constexpr int ERROR_DETAIL_LIMIT = 16; /** @brief 8-bit floating point type */ using F8 = ck_tile::fp8_t; diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp index 9f7227a699..9accf6e336 100644 --- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -7,6 +7,7 @@ #include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp" #include "grouped_convolution_signatures.hpp" +#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/conv/fwd.hpp" @@ -14,6 +15,9 @@ #include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/builder/conv_builder.hpp" +// Temporary disable builder validate since we don't have deduced rtol, atol support +#define ENABLE_BUILDER_VALIDATE 0 + namespace ck_tile::builder::profiling { namespace ckb = ck_tile::builder; @@ -117,22 +121,63 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args, auto ref_conv = ReferenceInstance{}; [[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get()); +#if ENABLE_BUILDER_VALIDATE == 0 + using DataType = + std::conditional_t>; + const auto conv_param = args.to_ck_tile_conv_param(); + + const std::size_t output_bytes_num = conv_param.template GetOutputByte(); + std::vector out(output_bytes_num / sizeof(DataType)); + std::vector ref(output_bytes_num / sizeof(DataType)); + HIP_CHECK_ERROR( + hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost)); + + const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(), + conv_param.filter_spatial_lengths_.cend(), + 1, + std::multiplies()) * + conv_param.C_; + float max_accumulated_value = *std::max_element(ref.begin(), ref.end()); + const auto rtol = ck_tile::get_relative_threshold(GemmK); + const auto atol = + ck_tile::get_absolute_threshold(max_accumulated_value, GemmK); +#endif + [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); if(is_supported) { + best_avg_time = std::min(best_avg_time, avg_time); + best_op_name = best_avg_time < avg_time ? best_op_name : op_name; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name + << std::endl; + +#if ENABLE_BUILDER_VALIDATE const auto errors = ckt::validate(args, outputs, reference.get()).get_errors(); for(const auto& error : errors) { valid = false; std::cout << "Number of incorrect values: " << error.wrong_elements - << " Is all zero:" << error.is_all_zero() << std::endl; + << " Is all zero:" << error.is_all_zero() + << " max err: " << error.max_error << std::endl; } - best_avg_time = std::min(best_avg_time, avg_time); - best_op_name = best_avg_time < avg_time ? best_op_name : op_name; - std::cout << "Perf: " << std::setw(10) << avg_time << " ms,"; +#else + HIP_CHECK_ERROR( + hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost)); + valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol); +#endif + + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + else + { + std::cout << " " << op_name << std::endl; } - std::cout << " " << op_name << std::endl; }; if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD) diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 6f8b71679c..725c5716d9 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -21,13 +21,12 @@ endif() if(GPU_TARGETS MATCHES "gfx9") if(CK_EXPERIMENTAL_BUILDER) - # TODO: Reenable after the instance fixes - # add_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) - # target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) - # if(TARGET device_grouped_conv_fwd_tile_instances) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) - # endif() + add_gtest_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) + target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) + if(TARGET device_grouped_conv_fwd_tile_instances) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) + endif() endif() endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp index 068811cf00..fe517572ff 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -13,6 +13,8 @@ // TODO: Remove limitation of conv fwd gpu reference which does not support right pad #define CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES 1 +// TODO: Remove this limitation after gpu reference fix +#define ENABLE_BHALF_GROUPED_CONV_FWD_TESTS 0 static ck::index_t args_mask = 0xffff; static ck::index_t instance_index = -1; @@ -67,7 +69,10 @@ class TestGroupedConvndFwdTile : public ::testing::Test auto inputs = alloc_inputs(args); auto outputs = alloc_outputs(args); - ckt::init_inputs(args, inputs.get()); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().input, args.make_input_descriptor(), -5, 5); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().weight, args.make_weight_descriptor(), -5, 5); std::cout << args.make_input_descriptor() << std::endl; std::cout << args.make_weight_descriptor() << std::endl; @@ -150,13 +155,12 @@ using KernelTypes2d = ::testing::Types, - SignatureDetails<2, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NHWGC, - ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 2, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK >> + ; +#endif using KernelTypes3d = ::testing::Types, - SignatureDetails<3, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NDHWGC, - ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 3, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK >> + ; +#endif template class TestGroupedConvndFwdTile2d : public TestGroupedConvndFwdTile From b66597ed96180ce21e7e6a6678dfc232ed07c800 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 27 Jan 2026 05:07:27 -0800 Subject: [PATCH 42/42] Add build time optimization documentation (#3608) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase, including: - Replacing recursive templates with pack expansion (O(N) → O(1) depth) - Using named functors instead of lambdas to share instantiations - Replacing template recursion with constexpr loops - Using fold expressions for accumulation operations These techniques can significantly reduce build times for template-heavy code. --- include/ck/BUILD_TIME_OPTIMIZATION.md | 225 ++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 include/ck/BUILD_TIME_OPTIMIZATION.md diff --git a/include/ck/BUILD_TIME_OPTIMIZATION.md b/include/ck/BUILD_TIME_OPTIMIZATION.md new file mode 100644 index 0000000000..94b292b878 --- /dev/null +++ b/include/ck/BUILD_TIME_OPTIMIZATION.md @@ -0,0 +1,225 @@ +# Build Time Optimization + +Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575) + +This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase. + +## Why Build Time Matters + +Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time. + +## Key Types + +This codebase uses compile-time types to enable zero-overhead abstractions: + +- `Number` - compile-time integer, enables static dispatch and compile-time arithmetic +- `Sequence` - compile-time integer sequence, used for dimension ordering and index manipulation +- `Tuple` - heterogeneous container holding different types, used for tensor descriptors and transforms + +These types allow the compiler to fully unroll loops, eliminate branches, and inline all operations - producing GPU kernels with no runtime abstraction cost. + +## Optimization Techniques + +### 1. Replace Recursive Templates with Pack Expansion + +Recursive template patterns create O(N) instantiation depth - the compiler must instantiate each level before proceeding to the next: + +``` +sequence_gen_impl<5, F> + → sequence_gen_impl<4, F> + → sequence_gen_impl<3, F> + → ... +``` + +Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth - the compiler generates the entire sequence in one step internally, without recursive template instantiation. + +**Before** (O(N) recursive instantiation): + +```cpp +template +struct sequence_gen_impl +{ + using type = typename sequence_gen_impl{}), Is...>::type; +}; + +template +struct sequence_gen_impl<0, F, Is...> +{ + using type = Sequence; +}; +``` + +**After** (constant depth using compiler intrinsic + pack expansion): + +```cpp +namespace detail { + +template +struct sequence_gen_helper +{ + // Apply functor F to all indices via pack expansion + // F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number{}) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + +template +struct sequence_gen +{ + // __make_integer_seq produces + // sequence_gen_helper with constant depth + using type = + typename __make_integer_seq::template apply; +}; +``` + +Note: This document assumes C++17 or later. While `std::make_integer_sequence` (introduced in C++14) is the standard library facility for generating integer sequences, it only produces `std::integer_sequence`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly. + +### 2. Replace Lambdas with Named Functors + +Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses. + +**Before** (lambda creates unique instantiations at each call site): + +```cpp +// The lambda inside transform_tensor_descriptor: +generate_tuple([](auto i) { return Sequence{}; }, Number{}); +``` + +**After** (named functor shares instantiations): + +```cpp +// Define functor once +struct generate_identity_sequence +{ + template + __host__ __device__ constexpr auto operator()(Number) const + { + return Sequence{}; + } +}; + +// Use everywhere - shares instantiations +generate_tuple(generate_identity_sequence{}, Number{}); +``` + +This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction). + +**Example: container_concat** + +```cpp +// Before: lambda creates unique type per call site +// (unpack2 applies a functor to all elements from both tuples) +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2([](auto&&... zs) { return make_tuple(forward(zs)...); }, tx, ty); +} + +// After: named functor shares instantiations +struct make_tuple_functor +{ + template + __host__ __device__ constexpr auto operator()(Ts&&... xs) const + { + return make_tuple(forward(xs)...); + } +}; + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2(make_tuple_functor{}, tx, ty); +} +``` + +This reduced `container_concat` instantiations from 186 to 93 (50% reduction). + +**Example: make_uniform_tuple** + +For patterns that create tuples with repeated values: + +```cpp +// Before: unique lambda type at each call site +generate_tuple([](auto) { return some_value; }, Number{}); + +// After: dedicated helper function +template +__host__ __device__ constexpr auto make_uniform_tuple(T&& value) +{ + return detail::make_uniform_tuple_impl(static_cast(value), make_index_sequence{}); +} + +// Usage +make_uniform_tuple(some_value); +``` + +### 3. Use Constexpr Loops Instead of Template Recursion + +Template recursion creates N template instantiations for N iterations. A constexpr loop executes at compile time but only requires a single template instantiation. While both are O(N) in complexity, constexpr loops are significantly faster because they avoid the overhead of template instantiation. + +**Before** (O(N) template instantiations): + +```cpp +// Simplified example - actual CK code used more complex recursive patterns +template +struct find_source_index_impl +{ + static constexpr index_t value = + (Seq::template At() == Target) ? Pos : find_source_index_impl::value; +}; + +template +struct find_source_index_impl +{ + static constexpr index_t value = -1; // not found +}; +``` + +**After** (single instantiation with constexpr loop): + +```cpp +template +__host__ __device__ constexpr index_t find_source_index(Sequence) +{ + // Simplified example - actual implementation handles empty sequences + constexpr index_t values[] = {Is...}; + for(index_t i = 0; i < sizeof...(Is); ++i) + if(values[i] == Target) return i; + return -1; // not found +} +``` + +This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%. + +### 4. Use Fold Expressions for Accumulation + +Fold expressions (C++17) can replace recursive template patterns for accumulation operations. + +**Before** (uses helper utilities that hide template recursion: `generate_tuple` recursively constructs a tuple of N elements, and `container_reduce` recursively reduces that tuple): + +```cpp +const auto element_space_size = container_reduce( + generate_tuple([&](auto i) { + return (lengths[i] - Number<1>{}) * strides[i]; + }, Number{}), + math::plus{}, Number<1>{}); +``` + +**After** (single fold expression): + +```cpp +template +__host__ __device__ constexpr auto compute_element_space_size( + const Tuple& lengths, + const Tuple& strides, + Sequence) +{ + return (LongNumber<1>{} + ... + + ((lengths[Number{}] - Number<1>{}) * strides[Number{}])); +} +``` + +This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%.