From 9c24b3c23e58879e1f580fafde8597cd6da7c29e Mon Sep 17 00:00:00 2001 From: rocking Date: Sat, 12 Aug 2023 01:31:31 +0800 Subject: [PATCH] Add Normalization splitk instances (#829) * Add normalization splitK to layernorm and groupnorm instances * Fix bug of GetKPerThread() * Refine naming * clang format [ROCm/composable_kernel commit: 03b8119e2eb7c2312eda9d32772c67f598304d86] --- .../gridwise_normalization_splitk_1st.hpp | 26 +++--- .../device_groupnorm_f16_instance.cpp | 2 + .../device_groupnorm_f32_instance.cpp | 2 + ...oupnorm_swish_f16_f32_f32_f16_instance.cpp | 2 + .../device_groupnorm_swish_f16_instance.cpp | 2 + .../device_groupnorm_swish_f32_instance.cpp | 2 + .../device_layernorm2d_f16_instance.cpp | 2 + .../device_layernorm2d_f32_instance.cpp | 2 + .../device_layernorm4d_f16_instance.cpp | 2 + .../device_layernorm4d_f32_instance.cpp | 2 + .../normalization_instance_common.hpp | 79 +++++++++++++++++++ .../profiler/profile_groupnorm_impl.hpp | 4 + .../profiler/profile_layernorm_impl.hpp | 4 + 13 files changed, 120 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp index fc42e97629..0fb961eb7a 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp @@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st static constexpr auto ThreadBufferNumber = Number{}; __device__ static int - GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) + GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) { bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; if(is_rightmost_block) { - int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); - int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock; - int kPerThread = - kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); - int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + int left_kPerBlock = math::integer_divide_ceil(k, kGridSize); + int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); + int kPerThread = kRightmostBlock < K_BlockTileSize + ? 0 + : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); + int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize; if(kPerBlockTail > 0) { @@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st } else { - int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); + int kPerBlock = math::integer_divide_ceil(k, kGridSize); return KThreadSliceSize * (kPerBlock / K_BlockTileSize); } } @@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st auto var_global_val_buf = make_dynamic_buffer( p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); - auto threadwise_welford = ThreadwiseWelford(); - int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; - threadwise_welford.max_count_ = - GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id); + auto threadwise_welford = ThreadwiseWelford(); + int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1), + kRaw, + k_grid_size, + block_k_cluster_id, + thread_k_cluster_id); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { mean_thread_buf(I) = type_convert(0.0f); diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp index e3820462cf..762da1c6ae 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f16_instances( add_device_operation_instances(instances, device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp index d85817aad3..44b553bd16 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances( add_device_operation_instances(instances, device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp index a81f776c0f..aa662b7dfe 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp @@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( instances, device_normalization_f16_f32_f32_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_f32_f32_f16_instances{}); + add_device_operation_instances( + instances, device_normalization_splitk_f16_f32_f32_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp index f4bb8bda81..bc5cd801ae 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances( add_device_operation_instances(instances, device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp index bbb9bd0fe8..4b2ab33570 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances( add_device_operation_instances(instances, device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp index 3f7e4aff1a..0d235f1fa7 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances( add_device_operation_instances(instances, device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp index 1f0db3a036..00039531e1 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances( add_device_operation_instances(instances, device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp index cb9d72e614..6bc3950062 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances( add_device_operation_instances(instances, device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp index ed555b840d..b387dc2f3f 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp @@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances( add_device_operation_instances(instances, device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp index b0684962f9..7aa3da8eed 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" #include "ck/utility/data_type.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -43,6 +44,32 @@ using device_normalization_f16_instances = // clang-format on >; +template +using device_normalization_splitk_f16_instances = + // clang-format off + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + template using device_normalization_f16_generic_instance = std::tuple< // clang-format off @@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple< // clang-format on >; +template +using device_normalization_splitk_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + template using device_normalization_f32_generic_instance = std::tuple< // clang-format off @@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple< // clang-format on >; +template +using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + template using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< // clang-format off diff --git a/profiler/include/profiler/profile_groupnorm_impl.hpp b/profiler/include/profiler/profile_groupnorm_impl.hpp index ebefe3dad4..f88ba8453c 100644 --- a/profiler/include/profiler/profile_groupnorm_impl.hpp +++ b/profiler/include/profiler/profile_groupnorm_impl.hpp @@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification, continue; } + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index 2d87c8c8fe..f969646c2f 100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -155,6 +155,10 @@ bool profile_layernorm_impl(int do_verification, continue; } + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});