From 53eab490622549f0f673e6056c196e6c3f3fac07 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 19 Dec 2023 04:23:11 +0800 Subject: [PATCH] layernorm and groupnorm backward data (#1083) * rename folder * Add type string * Remove typo * Add deviceOp to backward x * Add comment to describe the behavior of backward normalization * Add kernel function, prepare to implement * implement generic kernel * Check vector size * Add sweep once pipeline for small reduce size * Fix bug of KRaw_ error * Fix bug of dx stride * sanity check for mean and rstd * backward x for groupnorm * Add bwd x instance * add layernorm 2d bwd gamma beta instances * Change save mean var type from f32 to f16 in f16 mode * Change the example to f16 * Add groupnorm bwd gamma beta instance * Add groupnorm bwd x instance * Fix naming * Add layernorm bwd x ckprofiler * Add groupnorm bwd x profiler * clang format * Rename bwd x to bwd data * Fix bug of verification in profiler * Add test of layernorm and groupnorm bwd data * Add missing cmake * Add layernorm2d bwd data * rename fwd example * Add groupnorm client example * Fix typo. replace Invarient with Invariant * Add checking before running the best instance [ROCm/composable_kernel commit: a69aa2a11a83dc8e7b39be20aa50ec73d539dc28] --- client_example/01_gemm/gemm.cpp | 1 + .../gemm_add_add_fastgelu.cpp | 1 + .../gemm_add_fastgelu.cpp | 1 + .../gemm_fastgelu.cpp | 1 + .../gemm_add_relu_add_layernorm_welford.cpp | 1 + client_example/05_layernorm/CMakeLists.txt | 3 + .../05_layernorm/layernorm2d_bwd_data.cpp | 170 ++++++ .../05_layernorm/layernorm2d_fwd.cpp | 3 +- .../05_layernorm/layernorm4d_fwd.cpp | 3 +- client_example/06_softmax/softmax4d.cpp | 1 + .../elementwise_layernorm2d.cpp | 1 + .../gemm_add_multiply.cpp | 1 + client_example/18_groupnorm/CMakeLists.txt | 7 +- .../18_groupnorm/groupnorm_bwd_data.cpp | 182 ++++++ ...norm_swish.cpp => groupnorm_swish_fwd.cpp} | 0 .../20_splitk_gemm/splitK_gemm_fp16_f8.cpp | 1 + .../elementwise_transpose_3d.cpp | 1 + example/53_layernorm2d_bwd/CMakeLists.txt | 1 + .../layernorm2d_bwd_fp32.cpp} | 97 ++- example/53_layernorm_bwd/CMakeLists.txt | 1 - example/54_groupnorm_bwd/CMakeLists.txt | 2 +- ...rm_bwd_fp16.cpp => groupnorm_bwd_fp32.cpp} | 101 +++- .../device/device_normalization_bwd_data.hpp | 59 ++ .../device_normalization_bwd_data_impl.hpp | 465 +++++++++++++++ ...vice_normalization_bwd_gamma_beta_impl.hpp | 32 +- .../impl/device_normalization_fwd_impl.hpp | 6 +- .../device_normalization_fwd_splitk_impl.hpp | 4 +- .../gridwise_normalization_bwd_data.hpp | 554 ++++++++++++++++++ .../gridwise_normalization_bwd_gamma_beta.hpp | 11 +- .../cpu/reference_groupnorm_bwd.hpp | 25 + .../cpu/reference_layernorm_bwd.hpp | 24 + .../gpu/groupnorm_bwd_data.hpp | 64 ++ .../gpu/layernorm_bwd_data.hpp | 84 +++ .../gpu/normalization_fwd.hpp | 8 +- .../gpu/normalization_fwd_swish.hpp | 4 +- .../gpu/normalization_bwd_data/CMakeLists.txt | 8 + ...device_groupnorm_bwd_data_f32_instance.cpp | 22 + ...vice_layernorm2d_bwd_data_f16_instance.cpp | 23 + ...vice_layernorm2d_bwd_data_f32_instance.cpp | 23 + ...normalization_bwd_data_instance_common.hpp | 73 +++ .../CMakeLists.txt | 8 + ..._groupnorm_bwd_gamma_beta_f32_instance.cpp | 23 + ...ayernorm2d_bwd_gamma_beta_f16_instance.cpp | 24 + ...ayernorm2d_bwd_gamma_beta_f32_instance.cpp | 24 + ...ization_bwd_gamma_beta_instance_common.hpp | 73 +++ .../device_groupnorm_fwd_f16_instance.cpp | 2 +- ...evice_groupnorm_fwd_swish_f16_instance.cpp | 2 +- .../device_layernorm2d_fwd_f16_instance.cpp | 2 +- 
.../device_layernorm4d_fwd_f16_instance.cpp | 2 +- .../normalization_fwd_instance_common.hpp | 74 +-- .../profile_groupnorm_bwd_data_impl.hpp | 250 ++++++++ .../profile_layernorm_bwd_data_impl.hpp | 255 ++++++++ profiler/src/CMakeLists.txt | 3 + profiler/src/profile_groupnorm_bwd_data.cpp | 104 ++++ profiler/src/profile_groupnorm_fwd.cpp | 2 +- profiler/src/profile_layernorm_bwd_data.cpp | 112 ++++ profiler/src/profile_layernorm_fwd.cpp | 4 +- test/CMakeLists.txt | 1 + test/normalization_bwd_data/CMakeLists.txt | 13 + .../test_groupnorm_bwd_data_fp32.cpp | 51 ++ .../test_layernorm2d_bwd_data_fp32.cpp | 48 ++ .../test_groupnorm_fwd_fp16.cpp | 4 +- .../test_groupnorm_fwd_fp32.cpp | 2 +- .../test_layernorm2d_fwd_fp16.cpp | 4 +- .../test_layernorm4d_fwd_fp16.cpp | 4 +- 65 files changed, 3050 insertions(+), 110 deletions(-) create mode 100644 client_example/05_layernorm/layernorm2d_bwd_data.cpp create mode 100644 client_example/18_groupnorm/groupnorm_bwd_data.cpp rename client_example/18_groupnorm/{groupnorm_swish.cpp => groupnorm_swish_fwd.cpp} (100%) create mode 100644 example/53_layernorm2d_bwd/CMakeLists.txt rename example/{53_layernorm_bwd/layernorm2d_bwd_fp16.cpp => 53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp} (62%) delete mode 100644 example/53_layernorm_bwd/CMakeLists.txt rename example/54_groupnorm_bwd/{groupnorm_bwd_fp16.cpp => groupnorm_bwd_fp32.cpp} (62%) create mode 100644 include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp create mode 100644 profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp create mode 100644 profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp create mode 100644 profiler/src/profile_groupnorm_bwd_data.cpp create mode 100644 profiler/src/profile_layernorm_bwd_data.cpp create mode 100644 test/normalization_bwd_data/CMakeLists.txt create mode 100644 
test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp create mode 100644 test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index c37f208db1..11f9222873 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -185,6 +185,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 756889562e..e845c120d8 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 8d2a8c234a..e77b67c905 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -197,6 +197,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index c02df018fd..7648da9cac 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -190,6 +190,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index 3d5fb60048..93f8847c62 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -200,6 +200,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index 9cbfc2b763..246f877cde 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) +target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) + add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d_bwd_data.cpp b/client_example/05_layernorm/layernorm2d_bwd_data.cpp new file mode 100644 index 0000000000..9f26cb6840 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_data.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
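// [Editorial sketch, not part of the patch] What this new client example exercises:
// the backward-data op computes the layernorm input gradient dx from dy, x, gamma
// and the saved mean/inv_std. A minimal single-row host reference of the standard
// layernorm backward formula (the helper name and plain-float loop below are
// illustrative assumptions, not CK API):
//
//   x_hat = (x - mean) * inv_std
//   ds    = sum_k dy[k] * gamma[k] * x_hat[k]
//   db    = sum_k dy[k] * gamma[k]
//   dx[k] = inv_std * (dy[k] * gamma[k] - (db + x_hat[k] * ds) / N)
//
void layernorm_bwd_data_row_reference(const float* dy,
                                      const float* x,
                                      const float* gamma,
                                      float mean,
                                      float inv_std,
                                      float* dx,
                                      int N)
{
    float ds = 0.f; // sum of dy * gamma * x_hat over the reduced dimension
    float db = 0.f; // sum of dy * gamma over the reduced dimension
    for(int k = 0; k < N; ++k)
    {
        const float x_hat = (x[k] - mean) * inv_std;
        ds += dy[k] * gamma[k] * x_hat;
        db += dy[k] * gamma[k];
    }
    for(int k = 0; k < N; ++k)
    {
        const float x_hat = (x[k] - mean) * inv_std;
        dx[k] = inv_std * (dy[k] * gamma[k] - (db + x_hat * ds) / N);
    }
}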
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp"
+
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
+using MeanInvStdDataType = float;
+using DXDataType         = float;
+
+constexpr int Rank         = 2;
+constexpr int NumReduceDim = 1;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+
+    SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N);
+    SimpleDeviceMem x_dev(sizeof(XDataType) * M * N);
+    SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * N);
+    SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M);
+    SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M);
+    SimpleDeviceMem dx_dev(sizeof(DXDataType) * M * N);
+
+    using DeviceOp = ck::tensor_operation::device::
+        DeviceNormalizationBwdData<DYDataType,
+                                   XDataType,
+                                   GammaDataType,
+                                   MeanInvStdDataType,
+                                   DXDataType,
+                                   Rank,
+                                   NumReduceDim>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = std::numeric_limits<float>::max();
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
+                                                        {N, 1}, // dyStrides
+                                                        {N, 1}, // xStrides
+                                                        {0, 1}, // gammaStrides
+                                                        {1, 0}, // meanStrides
+                                                        {1, 0}, // invStdStrides
+                                                        {N, 1}, // dxStrides
+                                                        {1},    // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        gamma_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dx_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t num_byte = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N +
+                                   sizeof(GammaDataType) * N + sizeof(MeanInvStdDataType) * M * 2 +
+                                   sizeof(DXDataType) * M * N;
+
+            float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
+                      << op_name << std::endl;
+
+            if(ave_time < best_ave_time)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
+              << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+
auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d_fwd.cpp b/client_example/05_layernorm/layernorm2d_fwd.cpp index 19ddd614de..420225b613 100644 --- a/client_example/05_layernorm/layernorm2d_fwd.cpp +++ b/client_example/05_layernorm/layernorm2d_fwd.cpp @@ -16,7 +16,7 @@ using XDataType = ck::half_t; using GammaDataType = ck::half_t; using BetaDataType = ck::half_t; using YDataType = ck::half_t; -using SaveMeanInvStdDataType = float; +using SaveMeanInvStdDataType = ck::half_t; using PassThrough = ck::tensor_operation::element_wise::PassThrough; #define SAVE_MEAN_INV_STD @@ -150,6 +150,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/05_layernorm/layernorm4d_fwd.cpp b/client_example/05_layernorm/layernorm4d_fwd.cpp index 9a7ecfd87e..fa408dc751 100644 --- a/client_example/05_layernorm/layernorm4d_fwd.cpp +++ b/client_example/05_layernorm/layernorm4d_fwd.cpp @@ -16,7 +16,7 @@ using XDataType = ck::half_t; using GammaDataType = ck::half_t; using BetaDataType = ck::half_t; using YDataType = ck::half_t; -using SaveMeanInvStdDataType = float; +using SaveMeanInvStdDataType = ck::half_t; using PassThrough = ck::tensor_operation::element_wise::PassThrough; #define SAVE_MEAN_INV_STD @@ -155,6 +155,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index 2ccad27a88..a62af76635 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -140,6 +140,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index bc4a6fe0bf..8326f0758c 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -142,6 +142,7 @@ int main() << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = 
op_ptrs[best_op_id];
         std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
index c74d7c6bd8..cde4713b23 100644
--- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
+++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
@@ -204,6 +204,7 @@ int main(int argc, char* argv[])
               << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
 
     // run the best intance
+    if(found)
     {
         auto& op_ptr = op_ptrs[best_op_id];
diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt
index dee85f9a60..deb50f6fce 100644
--- a/client_example/18_groupnorm/CMakeLists.txt
+++ b/client_example/18_groupnorm/CMakeLists.txt
@@ -1,2 +1,5 @@
-add_executable(client_groupnorm_swish groupnorm_swish.cpp)
-target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_other_operations)
+add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
+target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)
+
+add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp)
+target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations)
diff --git a/client_example/18_groupnorm/groupnorm_bwd_data.cpp b/client_example/18_groupnorm/groupnorm_bwd_data.cpp
new file mode 100644
index 0000000000..01ca21ba57
--- /dev/null
+++ b/client_example/18_groupnorm/groupnorm_bwd_data.cpp
@@ -0,0 +1,182 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp"
+
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
+using MeanInvStdDataType = float;
+using DXDataType         = float;
+
+constexpr int Rank         = 5;
+constexpr int NumReduceDim = 3;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+    ck::index_t N = 32;
+    ck::index_t H = 16;
+    ck::index_t W = 16;
+    ck::index_t G = 64;
+    ck::index_t C = 128;
+
+    std::size_t length = N * H * W * G * C;
+
+    std::vector<ck::index_t> strideDy = {H * W * G * C, W * G * C, G * C, C, 1};
+    std::vector<ck::index_t> strideX  = strideDy;
+    std::vector<ck::index_t> strideDx = strideDy;
+
+    std::vector<ck::index_t> strideGamma      = {0, 0, 0, C, 1};
+    std::vector<ck::index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
+
+    SimpleDeviceMem dy_dev(sizeof(DYDataType) * length);
+    SimpleDeviceMem x_dev(sizeof(XDataType) * length);
+    SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * G * C);
+    SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G);
+    SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G);
+    SimpleDeviceMem dx_dev(sizeof(DXDataType) * length);
+
+    using DeviceOp = ck::tensor_operation::device::
+        DeviceNormalizationBwdData<DYDataType,
+                                   XDataType,
+                                   GammaDataType,
+                                   MeanInvStdDataType,
+                                   DXDataType,
+                                   Rank,
+                                   NumReduceDim>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = std::numeric_limits<float>::max();
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr      = op_ptrs[i];
+        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},
+                                                        strideDy,
+                                                        strideX,
+                                                        strideGamma,
+                                                        strideMeanInvStd,
+                                                        strideMeanInvStd,
+                                                        strideDx,
+                                                        {1, 2, 4}, // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        gamma_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dx_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t num_byte = sizeof(DYDataType) * length + sizeof(XDataType) * length +
+                                   sizeof(GammaDataType) * G * C +
+                                   sizeof(MeanInvStdDataType) * N * G * 2 +
+                                   sizeof(DXDataType) * length;
+
+            float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
+                      << op_name << std::endl;
+
+            if(ave_time < best_ave_time)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    // run the best instance
+    if(found)
+    {
+        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
+                  << best_op_name << std::endl;
+
+        auto& op_ptr = op_ptrs[best_op_id];
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},
+                                                        strideDy,
+                                                        strideX,
+                                                        strideGamma,
+                                                        strideMeanInvStd,
+                                                        strideMeanInvStd,
+                                                        strideDx,
+                                                        {1, 2, 4}, // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        gamma_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dx_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp
similarity index 100%
rename from client_example/18_groupnorm/groupnorm_swish.cpp
rename to client_example/18_groupnorm/groupnorm_swish_fwd.cpp
diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
index 94a57cd029..a740c22f91 100644
--- a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
+++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
@@ -191,6 +191,7 @@ int main(int argc, char* argv[])
               << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
 
     // run the best intance
+    if(found)
     {
         auto& op_ptr = op_ptrs[best_op_id];
diff --git a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp
index fb63e20147..65ba46fcd2 100644
--- a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp
+++ b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp
@@ -117,6 +117,7 @@ int main()
               << best_op_name << std::endl;
 
     // run the best intance
+    if(found)
     {
         auto& op_ptr = op_ptrs[best_op_id];
         std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
diff --git a/example/53_layernorm2d_bwd/CMakeLists.txt b/example/53_layernorm2d_bwd/CMakeLists.txt
new file mode 100644
index 0000000000..a58b1109f7
--- /dev/null
+++ b/example/53_layernorm2d_bwd/CMakeLists.txt
@@ -0,0 +1 @@
+add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
diff --git a/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp
similarity index 62%
rename from example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp
rename to example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp
index f2e6bfb44d..0b0a3e72ad 100644
--- a/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp
+++ b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp
@@ -15,16 +15,17 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 
 #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
 
-using DYDataType    = ck::half_t;
-using XDataType     = ck::half_t;
-using GammaDataType = ck::half_t;
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
 using MeanInvStdDataType = float;
-using DGammaDataType = ck::half_t;
-using DBetaDataType  = ck::half_t;
-using DXDataType     = ck::half_t;
+using DGammaDataType     = float;
+using DBetaDataType      = float;
+using DXDataType         = float;
 using ComputeDataType = float;
 
 constexpr int Rank = 2;
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
 // inv_std: [M, 1]
 
 // Output shape
+// dx: [M, N]
 // dgamma: [1, N]
 // dbeta: [1, N]
 
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
 // dbeta = reduce_sum(dy, axis=0)
 
 // [CAUSION]
-// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension
-// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different
+// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is the invariant
+// dimension and K is the reduced dimension. Hence, M in this example and M in
+// DeviceNormalizationBwdGammaBetaImpl are different.
+using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
+    DYDataType,
+    XDataType,
+    GammaDataType,
+    MeanInvStdDataType,
+    ComputeDataType,
+    DXDataType,
+    Rank,
+    NumReduceDim,
+    256,   // BlockSize
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    1,     // MThreadSliceSize
+    4,     // KThreadSliceSize
+    true,  // IsDYFastestDimReduced
+    4,     // DYSrcVectorSize
+    true,  // IsXFastestDimReduced
+    4,     // XSrcVectorSize
+    true,  // IsGammaFastestDimReduced
+    4,     // GammaSrcVectorSize
+    false, // IsMeanInvStdFastestDimReduced
+    1,     // MeanInvStdSrcVectorSize
+    true,  // IsDXFastestDimReduced
+    4>;    // DXDstVectorSize
+
 using
GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< DYDataType, XDataType, @@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio Rank, NumReduceDim, 256, // BlockSize - 8, // ClusterInvarient - 32, // ClusterReduce - 8, // SliceInvarient - 1, // SliceReduce + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 4, // MThreadSliceSize + 1, // KThreadSliceSize false, // IsDYFastestDimReduced - 8, // DYSrcVectorSize + 4, // DYSrcVectorSize false, // IsXFastestDimReduced - 8, // XSrcVectorSize + 4, // XSrcVectorSize true, // IsMeanInvStdFastestDimReduced 1, // MeanInvStdSrcVectorSize - 1, // DGammaDstVectorSize - 1>; // DBetaDstVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize int main() { @@ -96,16 +124,48 @@ int main() DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); dy_dev.ToDevice(dy.mData.data()); x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); mean_dev.ToDevice(mean.mData.data()); inv_std_dev.ToDevice(inv_std.mData.data()); + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; auto gamma_beta_argument_ptr = gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths @@ -126,7 +186,8 @@ int main() if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) { - std::cout << "The runtime parameters are not supported" << std::endl; + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; return 1; }; @@ -156,9 +217,11 @@ int main() dgamma_dev.FromDevice(dgamma.mData.data()); dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); } return (pass ? 
0 : 1); diff --git a/example/53_layernorm_bwd/CMakeLists.txt b/example/53_layernorm_bwd/CMakeLists.txt deleted file mode 100644 index 24db221523..0000000000 --- a/example/53_layernorm_bwd/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp) diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt index ac548cbc79..2cb103499c 100644 --- a/example/54_groupnorm_bwd/CMakeLists.txt +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -1 +1 @@ -add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp) +add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp) diff --git a/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp similarity index 62% rename from example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp rename to example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp index 1537a014d4..6cf1b2ff91 100644 --- a/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp +++ b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp @@ -15,23 +15,58 @@ #include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" -using DYDataType = ck::half_t; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; using MeanInvStdDataType = float; -using DGammaDataType = ck::half_t; -using DBetaDataType = ck::half_t; -using DXDataType = ck::half_t; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; using ComputeDataType = float; constexpr int Rank = 5; constexpr int NumReduceDim = 3; // Grouprnorm -// kernel: M , K +// kernel 1: M , K +// dy: N, H, W, G, C -> N * G, H * W * C +// x: N, H, W, G, C -> N * G, H * W * C +// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C +// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 +// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 + +// dx: N, H, W, G, C -> N * G, H * W * C + +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +// kernel 2: M , K // dy: N, H, W, G, C -> G * C, N * H * W // x: N, H, W, G, C -> G * C, N * H * W // mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 @@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio Rank, NumReduceDim, 256, // BlockSize - 8, // ClusterInvarient + 8, // ClusterInvariant 32, // ClusterReduce - 8, // SliceInvarient + 4, // SliceInvariant 1, // SliceReduce false, // IsDYFastestDimReduced - 8, // DYSrcVectorSize + 4, // DYSrcVectorSize false, // IsXFastestDimReduced - 8, // XSrcVectorSize + 4, // XSrcVectorSize false, // IsMeanInvStdFastestDimReduced 1, // 
MeanInvStdSrcVectorSize
-    1,   // DGammaDstVectorSize
-    1>;  // DBetaDstVectorSize
+    4,   // DGammaDstVectorSize
+    4>;  // DBetaDstVectorSize
 
 int main()
 {
@@ -93,20 +128,55 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
 
     dy_dev.ToDevice(dy.mData.data());
     x_dev.ToDevice(x.mData.data());
+    gamma_dev.ToDevice(gamma.mData.data());
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());
 
     std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
     std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> gammaStrides  = {0, 0, 0, C, 1};
     std::vector<ck::index_t> meanStrides   = {G, 0, 0, 1, 0};
     std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
+    std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
+
+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+
+    auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
+                                                                dyStrides,      // dyStrides
+                                                                xStrides,       // xStrides
+                                                                gammaStrides,   // gammaStrides
+                                                                meanStrides,    // meanStrides
+                                                                invStdStrides,  // invStdStrides
+                                                                dxStrides,      // dxStrides
+                                                                {1, 2, 4},      // reduceDims
+                                                                dy_dev.GetDeviceBuffer(),
+                                                                x_dev.GetDeviceBuffer(),
+                                                                gamma_dev.GetDeviceBuffer(),
+                                                                mean_dev.GetDeviceBuffer(),
+                                                                inv_std_dev.GetDeviceBuffer(),
+                                                                dx_dev.GetDeviceBuffer());
+
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
+        return 1;
+    };
+
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
 
     auto gamma_beta_argument_ptr =
@@ -128,7 +198,8 @@ int main()
 
     if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
     {
-        std::cout << "The runtime parameters are not supported" << std::endl;
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
         return 1;
     };
 
@@ -158,9 +229,11 @@ int main()
 
         dgamma_dev.FromDevice(dgamma.mData.data());
         dbeta_dev.FromDevice(dbeta.mData.data());
+        dx_dev.FromDevice(dx.mData.data());
 
         pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
         pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
     }
 
     return (pass ? 0 : 1);
diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp
new file mode 100644
index 0000000000..327acfeb53
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp
@@ -0,0 +1,59 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
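// [Editorial sketch, not part of the patch] How the groupnorm example above maps a
// [N, H, W, G, C] problem onto the 2-D (M, K) view this interface reduces over;
// the values mirror the example and the helper below is illustrative only:
#include <cstdint>
#include <utility>
#include <vector>

inline std::pair<int64_t, int64_t> groupnorm_bwd_data_mk_view()
{
    const int64_t N = 32, H = 16, W = 16, G = 64, C = 128;

    // reduceDims = {1, 2, 4}: H, W and C are reduced, N and G stay invariant.
    const int64_t M = N * G;     // invariant length: 32 * 64       = 2048
    const int64_t K = H * W * C; // reduced length:   16 * 16 * 128 = 32768

    // A stride of 0 re-reads the same element along that dimension, which is how
    // gamma [1, 1, 1, G, C] and mean/inv_std [N, 1, 1, G, 1] are broadcast over
    // the dimensions they do not carry.
    const std::vector<int64_t> gammaStrides      = {0, 0, 0, C, 1};
    const std::vector<int64_t> meanInvStdStrides = {G, 0, 0, 1, 0};
    (void)gammaStrides;
    (void)meanInvStdStrides;

    return {M, K};
}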
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdDataPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp new file mode 100644 index 0000000000..86689af0b7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { +template +__global__ void +kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M_K dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) +{ + GridwiseNormalizationBwd::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dx_grid_desc_m_k, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_gamma_global, + p_mean_global, + p_inv_std_global, + p_dx_global); +}; + +template +struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) || + (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto Make2dDescriptor(const std::vector& lengths, + const std::vector& strides, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(strides, Number{}); + + const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(lengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + + return transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, pad_M), + make_right_pad_transform(reduceLength, pad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k_padded; + } + + using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)); + + using GridwiseNormalizationBwdDataGeneric = + GridwiseNormalizationBwdData_mk_to_mk; + + using GridwiseNormalizationBwdDataSweepOnce = + GridwiseNormalizationBwdData_mk_to_mk; + + struct 
Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const GammaDataType* p_gamma, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DXDataType* p_dx) + : p_dy_(p_dy), + p_x_(p_x), + p_gamma_(p_gamma), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dx_(p_dx) + { + lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + dxStrides_ = shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); + dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); + + isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DXDataType* p_dx_; + + std::vector lengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector dxStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // tensor descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + GridDesc_M_K dx_grid_desc_m_k_; + + bool isSweeponce_; + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + auto KernelSelector(bool isSweepOnce) + { + return isSweepOnce + ? 
kernel_normalization_bwd_data + : kernel_normalization_bwd_data; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = KernelSelector(arg.isSweeponce_); + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dx_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_gamma_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dx_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dyStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->xStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->gammaStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->meanStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->invStdStrides_); + + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dxStrides_); + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) override + { + if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + gammaStrides.size() != Rank || meanStrides.size() != Rank || + invStdStrides.size() != Rank || dxStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + dyStrides, + xStrides, + gammaStrides, + meanStrides, + invStdStrides, + dxStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dx)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize; + str 
<< ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp index 43d2db4624..c35652bfe8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -14,7 +14,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" -// M is invarient dimension, K is reduced dimension +// M is Invariant dimension, K is reduced dimension namespace ck { namespace tensor_operation { namespace device { @@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl Rank, NumReduceDim> { - static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; @@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); - static_assert( - ((MThreadSliceSize % DGammaDstVectorSize == 0) || - (MThreadSliceSize % DBetaDstVectorSize == 0)), - "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " - "check!"); - static_assert( (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " "check!"); + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; @@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl GridDesc_M dgamma_grid_desc_m_; GridDesc_M dbeta_grid_desc_m_; - index_t MRaw_; // invarient length + index_t MRaw_; // Invariant length index_t KRaw_; // reduce length }; @@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ; + str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">"; + // clang-format on + + return str.str(); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp index 254d60ea38..fa7b9cbda0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp @@ -19,7 +19,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwdinvariant_lowest_length_); - if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp index 6a117920f4..7fe6502bab 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp @@ -108,7 +108,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd +struct GridwiseNormalizationBwdData_mk_to_mk +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) || + (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) || + (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using GammaThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using DXThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + 
make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M_K& dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_val_buf = make_dynamic_buffer( + p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto gamma_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dx_thread_buf = StaticBuffer{}; + + auto ds_thread_buf = + StaticBuffer{}; + + auto db_thread_buf = + StaticBuffer{}; + + // thread id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + 
mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + ComputeDataType reduce_size = type_convert( + dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + ds_thread_buf(I) = type_convert(0.0f); + db_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K + // we don't need to use loop to read x, dy, gamma twice + if constexpr(SweepOnce) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + } // end of sweep once + else // Sweep Twice 
pipeline + { + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + } // end of first sweep + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + // reverse read for using dy, gamma and x in the cache + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + // move to tail + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + + // move from start to tail + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * 
mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp index e80c360bb2..21248e3a0a 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -35,7 +35,7 @@ template struct GridwiseNormalizationBwdGammaBeta_mk_to_k { - // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); @@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize, + "Invalid thread slice sizes and/or dgamma and dbeta vector sizes configuration, please check!"); + using ThreadClusterLengths_M_K = Sequence; using DYThreadBufferDimAccessOrder = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp index 37f875d07a..2cc9a50ba0 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp @@ -16,6 +16,31 @@ namespace ck { namespace tensor_operation { namespace host { +// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size): +// ds = np.sum(dy * gamma
* x, axis=reduce_axis, keepdims=True) +// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True) +// b = (db * x_mean - ds) * rstd ** (3) / reduce_size +// c = -b * x_mean - db * rstd / reduce_size +// dx = rstd * dy * gamma + b * x + c +// return dx + +// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis): +// # Assume gamma and beta have the same shape +// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True) +// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True) +// return dgamma, dbeta + +// def groupnorm_backward(dy, x, gamma, x_mean, rstd): +// # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1] +// N, H, W, G, C = x.shape +// dx = normalization_backward_x( +// dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C) +// dgamma, dbeta = normalization_backward_gamma_beta( +// dy, x, x_mean, rstd, (0, 1, 2)) +// return dx, dgamma, dbeta + +// Reference (Layernorm and groupnorm): +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655 template +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_groupnorm_bwd_data_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdData> +{ + using DeviceOp = DeviceNormalizationBwdData; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_groupnorm_bwd_data_f32_instances(op_ptrs); + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp new file mode 100644 index 0000000000..c46cdec336 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
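As a reading aid for the backward-data math above (ds, db, b, c, dx), here is a minimal scalar C++ sketch of the same computation for one reduced row of length K. It is illustrative only and not part of this patch; the helper name layernorm_bwd_data_ref is hypothetical.

#include <cstddef>
#include <vector>

// Scalar sketch of the bwd-data formulas used by the gridwise kernel above:
//   b  = (db * mean - ds) * rstd^3 / reduce_size
//   c  = -b * mean - db * rstd / reduce_size
//   dx = rstd * dy * gamma + b * x + c
// One row of a rank-2 [M, K] layernorm problem; reduce_size == K.
std::vector<float> layernorm_bwd_data_ref(const std::vector<float>& dy,
                                          const std::vector<float>& x,
                                          const std::vector<float>& gamma,
                                          float mean,
                                          float rstd)
{
    const std::size_t K = dy.size();
    const float n       = static_cast<float>(K);

    float ds = 0.f, db = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        ds += dy[k] * gamma[k] * x[k]; // sum(dy * gamma * x)
        db += dy[k] * gamma[k];        // sum(dy * gamma)
    }

    const float b = (db * mean - ds) * rstd * rstd * rstd / n;
    const float c = -b * mean - db * rstd / n;

    std::vector<float> dx(K);
    for(std::size_t k = 0; k < K; ++k)
        dx[k] = rstd * dy[k] * gamma[k] + b * x[k] + c;
    return dx;
}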
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +// FP16 +void add_device_layernorm2d_bwd_data_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_layernorm2d_bwd_data_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdData> +{ + using DeviceOp = DeviceNormalizationBwdData; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_data_f16_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_data_f32_instances(op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp index 29c9f8b2c0..d19e50297f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp @@ -20,15 +20,15 @@ namespace instance { // FP16 void add_device_normalization_fwd_rank_2_1_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); void add_device_normalization_fwd_rank_4_3_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); void add_device_normalization_fwd_rank_5_3_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); #endif #ifdef CK_ENABLE_FP32 // FP32 @@ -76,7 +76,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp index 563f164fd2..cc15e7bacc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp @@ -19,7 +19,7 @@ namespace instance { // FP16 void add_device_normalization_fwd_rank_5_3_swish_f16_instances( - std::vector>>&); + std::vector>>&); // FP32 void add_device_normalization_fwd_rank_5_3_swish_f32_instances( @@ -61,7 +61,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(Rank == 5 && NumReduceDim == 3) { diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000..9f3dd9d94c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt @@ 
-0,0 +1,8 @@ +set(DEVICE_NORMALIZATION_bwd_data_INSTANCES) + +list(APPEND DEVICE_NORMALIZATION_bwd_data_INSTANCES + device_groupnorm_bwd_data_f32_instance.cpp + device_layernorm2d_bwd_data_f16_instance.cpp + device_layernorm2d_bwd_data_f32_instance.cpp) + +add_instance_library(device_normalization_bwd_data_instance ${DEVICE_NORMALIZATION_bwd_data_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp new file mode 100644 index 0000000000..7b4974d648 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_groupnorm_bwd_data_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_generic_instance{}); + add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp new file mode 100644 index 0000000000..097f9a3814 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_data_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_data_f16_generic_instance<2, 1>{}); + add_device_operation_instances(instances, device_layernorm_bwd_data_f16_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp new file mode 100644 index 0000000000..d885a77e5c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
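For orientation, the factory specialization declared in layernorm_bwd_data.hpp above is typically consumed as sketched below. This is a hedged illustration, not patch content: it assumes the template parameter order DeviceNormalizationBwdData<DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim> and mirrors the MakeArgumentPointer argument list used by the profiler later in this patch; M, K and the device pointers are caller-supplied placeholders.

#include <vector>
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp"

using DeviceOp = ck::tensor_operation::device::
    DeviceNormalizationBwdData<float, float, float, float, float, 2, 1>;

// Run the first registered f32 rank-2 instance that supports the problem.
bool run_layernorm2d_bwd_data_f32(ck::index_t M, ck::index_t K,
                                  const void* p_dy, const void* p_x, const void* p_gamma,
                                  const void* p_mean, const void* p_inv_std, void* p_dx)
{
    using ck::index_t;
    const std::vector<index_t> lengths{M, K};
    const std::vector<index_t> xy_strides{K, 1};    // row-major [M, K] for dy, x, dx
    const std::vector<index_t> gamma_strides{0, 1}; // gamma[K], broadcast over M
    const std::vector<index_t> stat_strides{1, 0};  // mean/inv_std[M], broadcast over K
    const std::vector<index_t> reduce_dims{1};

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op : op_ptrs)
    {
        auto arg = op->MakeArgumentPointer(lengths,
                                           xy_strides,    // dy strides
                                           xy_strides,    // x strides
                                           gamma_strides, // gamma strides
                                           stat_strides,  // mean strides
                                           stat_strides,  // inv_std strides
                                           xy_strides,    // dx strides
                                           reduce_dims,
                                           p_dy, p_x, p_gamma, p_mean, p_inv_std, p_dx);
        if(op->IsSupportedArgument(arg.get()))
        {
            op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{nullptr, false});
            return true;
        }
    }
    return false; // no instance supported this problem
}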
+ +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_data_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_data_f32_generic_instance<2, 1>{}); + add_device_operation_instances(instances, device_layernorm_bwd_data_f32_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp new file mode 100644 index 0000000000..4f72a8782b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_layernorm_bwd_data_f16_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +using device_groupnorm_bwd_data_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + 
DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +using device_groupnorm_bwd_data_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt new file mode 100644 index 0000000000..686fb5e665 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt @@ -0,0 +1,8 @@ +set(DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES) + +list(APPEND DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES + device_groupnorm_bwd_gamma_beta_f32_instance.cpp + device_layernorm2d_bwd_gamma_beta_f16_instance.cpp + device_layernorm2d_bwd_gamma_beta_f32_instance.cpp) + +add_instance_library(device_normalization_bwd_gamma_beta_instance ${DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp new file mode 100644 index 0000000000..8eaf8a5684 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_groupnorm_bwd_gamma_beta_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_groupnorm_bwd_gamma_beta_f32_instances{}); + add_device_operation_instances(instances, + device_groupnorm_bwd_gamma_beta_f32_generic_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp new file mode 100644 index 0000000000..aa399f56ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f16_generic_instance<2, 1>{}); + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f16_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp new file mode 100644 index 0000000000..ba2966ba37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f32_generic_instance<2, 1>{}); + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f32_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp new file mode 100644 index 0000000000..3f239527a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_layernorm_bwd_gamma_beta_f16_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +using device_groupnorm_bwd_gamma_beta_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +using device_groupnorm_bwd_gamma_beta_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp index 0f8bab973e..9b58e12e1d 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_5_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git 
a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp index 9fbbab64e7..fe42f3dfec 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Swish = ck::tensor_operation::element_wise::Swish; void add_device_normalization_fwd_rank_5_3_swish_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp index bfc2e465df..8fcab734d3 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_2_1_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp index 690489bbad..132eef0146 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_4_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp index 60a55dd6e1..9486a033de 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp @@ -23,24 +23,24 @@ using device_normalization_f16_instances = // clang-format off std::tuple < // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - 
DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; @@ -49,31 +49,31 @@ using device_normalization_splitk_f16_instances = // clang-format off std::tuple < // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl // clang-format on >; template using device_normalization_f16_generic_instance = std::tuple< // clang-format off - DeviceNormalizationFwdImpl + DeviceNormalizationFwdImpl // clang-format on >; diff --git a/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp b/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp new file mode 100644 index 0000000000..55ea08e0db --- /dev/null +++ b/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
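One subtle point in the profiler implementation that follows: gamma is stored as [G, C] and mean/inv_std as [N, G], but both are handed to the device op as rank-5 [N, H, W, G, C] views whose broadcast dimensions carry stride 0. The self-contained sketch below (arbitrary example values, not patch content) shows how that zero-stride addressing repeats elements along the broadcast dimensions.

#include <cassert>
#include <cstddef>
#include <vector>

// Offset of a multi-dimensional index under explicit strides.
inline long flatten(const std::vector<long>& idx, const std::vector<long>& stride)
{
    long off = 0;
    for(std::size_t d = 0; d < idx.size(); ++d)
        off += idx[d] * stride[d];
    return off;
}

int main()
{
    const long G = 8, C = 16;
    const std::vector<long> gamma_stride = {0, 0, 0, C, 1}; // gamma[G, C], broadcast over N, H, W
    const std::vector<long> mean_stride  = {G, 0, 0, 1, 0}; // mean[N, G], broadcast over H, W, C

    // Every (n, h, w) maps to the same gamma element for a fixed (g, c):
    assert(flatten({0, 1, 2, 3, 5}, gamma_stride) == flatten({1, 3, 0, 3, 5}, gamma_stride));
    // Every (h, w, c) maps to the same mean element for a fixed (n, g):
    assert(flatten({1, 0, 0, 3, 0}, mean_stride) == flatten({1, 2, 3, 3, 9}, mean_stride));
    return 0;
}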
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_groupnorm_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need DGamma and DBeta here, just for reference class + using DGammaDataType = DXDataType; + using DBetaDataType = DXDataType; + + if(length.size() != 5) + return false; + + index_t N = length[0]; + index_t G = length[3]; + index_t C = length[4]; + + std::vector reduce_dim = {1, 2, 4}; + std::vector gammaLength = {G, C}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma({G, C}); + Tensor mean({N, G}); + Tensor inv_std({N, G}); + Tensor dx(length); + + Tensor host_dx(length); + Tensor host_dgamma({G, C}); + Tensor host_dbeta({G, C}); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = {0, 0, 0, C, 1}; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dx.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dx.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dx.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = 
std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dx.mDesc.GetElementSize() * sizeof(DXDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dx_dev.FromDevice(dx.mData.data()); + bool pass = ck::utils::check_err( + dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl; + LogRangeAsType(std::cout << "dx : ", dx.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." 
<< std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp b/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp new file mode 100644 index 0000000000..e88a06122d --- /dev/null +++ b/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_layernorm_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need DGamma and DBeta here, just for reference class + using DGammaDataType = DXDataType; + using DBetaDataType = DXDataType; + + if(length.size() != Rank || Rank < 2) + return false; + + // Assume normalize dimension except for batch (first) dimension + std::vector reduce_length{length.begin() + 1, length.end()}; + std::vector reduce_dim; + for(int i = 1; i < Rank; ++i) + reduce_dim.push_back(i); + + Tensor dy(length); + Tensor x(length); + Tensor gamma(reduce_length); + Tensor mean({length[0]}); + Tensor inv_std({length[0]}); + Tensor dx(length); + + Tensor host_dx(length); + Tensor host_dgamma(reduce_length); + Tensor host_dbeta(reduce_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = strideDy; + strideGamma[0] = 0; + + std::vector strideMeanInvStd{Rank, 0}; + strideMeanInvStd[0] = 1; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dx.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dx.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
dx.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + constexpr int NumReduceDim = Rank - 1; + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dx.mDesc.GetElementSize() * sizeof(DXDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dx_dev.FromDevice(dx.mData.data()); + bool pass = ck::utils::check_err( + dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + 
LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl; + LogRangeAsType(std::cout << "dx : ", dx.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 0af3107157..7674b3b4f0 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -16,7 +16,9 @@ set(PROFILER_SOURCES profile_grouped_conv_fwd.cpp profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp + profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp + profile_layernorm_bwd_data.cpp profile_layernorm_fwd.cpp profile_max_pool3d_fwd.cpp profile_avg_pool3d_bwd.cpp @@ -78,6 +80,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_w target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) diff --git a/profiler/src/profile_groupnorm_bwd_data.cpp b/profiler/src/profile_groupnorm_bwd_data.cpp new file mode 100644 index 0000000000..f9fea1db55 --- /dev/null +++ b/profiler/src/profile_groupnorm_bwd_data.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "profiler/data_type_enum.hpp"
+#include "profiler/profile_groupnorm_bwd_data_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+using ck::index_t;
+
+struct groupnormBwdDataArgParser
+{
+    std::unordered_map<std::string, std::vector<index_t>> long_opts = {{"length", {}}};
+
+    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
+    {
+        if(std::string("--") + key == argv[i])
+        {
+            int pos = i;
+            while(++i < argc && argv[i][0] != '-') {}
+            int end = i;
+            for(int j = pos + 1; j < end; j++)
+            {
+                long_opts[key].push_back(std::stoi(argv[j]));
+            }
+            return true;
+        }
+        return false;
+    }
+
+    void operator()(int argc, char* argv[])
+    {
+        for(auto& kv : long_opts)
+        {
+            for(int i = 1; i < argc; i++)
+            {
+                if(parse_opt(argc, argv, kv.first, i))
+                    break;
+            }
+        }
+    }
+};
+
+void print_help_groupnorm_bwd_data()
+{
+    // e.g.: ckProfiler groupnorm_bwd_data 1 0 2 0 1 --length 1 16 16 32 40
+    std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
+              << "arg2: verification (0: no; 1: yes)\n"
+              << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+              << "arg4: print tensor value (0: no; 1: yes)\n"
+              << "arg5: time kernel (0: no; 1: yes)\n"
+              << "--length: tensor extents (e.g., --length 1 16 16 32 40)\n"
+              << std::endl;
+}
+
+int profile_groupnorm_bwd_data(int argc, char* argv[])
+{
+    // argv[1] is the operation name; argv[2..6] are the five short options
+    if(argc < 7)
+    {
+        print_help_groupnorm_bwd_data();
+        return 0;
+    }
+
+    groupnormBwdDataArgParser arg_parser;
+
+    // short unnamed options
+    const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
+    const bool do_verification       = std::stoi(argv[3]);
+    const int init_method            = std::stoi(argv[4]);
+    const bool do_log                = std::stoi(argv[5]);
+    const bool time_kernel           = std::stoi(argv[6]);
+
+    // parse the long options
+    arg_parser(argc, argv);
+    const std::vector<index_t> length = arg_parser.long_opts["length"];
+
+    using F32 = float;
+
+    if(length.size() == 5)
+    {
+        if(data_type == ck::DataTypeEnum::Float)
+        {
+            ck::profiler::profile_groupnorm_bwd_data_impl<F32, F32, F32, F32, F32, F32>(
+                do_verification, init_method, do_log, time_kernel, length);
+        }
+        else
+        {
+            throw std::runtime_error("not implemented yet");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("length should be 5");
+    }
+
+    return 0;
+}
+
+REGISTER_PROFILER_OPERATION("groupnorm_bwd_data",
+                            "Group Normalization",
+                            profile_groupnorm_bwd_data);
diff --git a/profiler/src/profile_groupnorm_fwd.cpp b/profiler/src/profile_groupnorm_fwd.cpp
index 3ba2f751cc..9a595bf7a7 100644
--- a/profiler/src/profile_groupnorm_fwd.cpp
+++ b/profiler/src/profile_groupnorm_fwd.cpp
@@ -98,7 +98,7 @@ int profile_groupnorm(int argc, char* argv[])
     }
     else if(data_type == ck::DataTypeEnum::Half)
     {
-        ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32>(
+        ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F16>(
             do_verification, init_method, do_log, time_kernel, length);
     }
     else
diff --git a/profiler/src/profile_layernorm_bwd_data.cpp b/profiler/src/profile_layernorm_bwd_data.cpp
new file mode 100644
index 0000000000..1f364d79ba
--- /dev/null
+++ b/profiler/src/profile_layernorm_bwd_data.cpp
@@ -0,0 +1,112 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
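+
+// ckProfiler entry point for "layernorm_bwd_data" on 2-D [N, D] problems.
+// The greedy "--length" parser below collects integers until the next
+// "-"-prefixed token, e.g. "--length 1502 4096" becomes {1502, 4096}.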
+
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "profiler/data_type_enum.hpp"
+#include "profiler/profile_layernorm_bwd_data_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+using ck::index_t;
+
+struct layernormBwdDataArgParser
+{
+    std::unordered_map<std::string, std::vector<index_t>> long_opts = {{"length", {}}};
+
+    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
+    {
+        if(std::string("--") + key == argv[i])
+        {
+            int pos = i;
+            while(++i < argc && argv[i][0] != '-') {}
+            int end = i;
+            for(int j = pos + 1; j < end; j++)
+            {
+                long_opts[key].push_back(std::stoi(argv[j]));
+            }
+            return true;
+        }
+        return false;
+    }
+
+    void operator()(int argc, char* argv[])
+    {
+        for(auto& kv : long_opts)
+        {
+            for(int i = 1; i < argc; i++)
+            {
+                if(parse_opt(argc, argv, kv.first, i))
+                    break;
+            }
+        }
+    }
+};
+
+void print_help_layernorm_bwd_data()
+{
+    // e.g.: ckProfiler layernorm_bwd_data 0 0 2 0 1 --length 1502 4096
+    std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
+              << "arg2: verification (0: no; 1: yes)\n"
+              << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+              << "arg4: print tensor value (0: no; 1: yes)\n"
+              << "arg5: time kernel (0: no; 1: yes)\n"
+              << "--length: tensor extents (e.g., --length 1024 1024)\n"
+              << std::endl;
+}
+
+int profile_layernorm_bwd_data(int argc, char* argv[])
+{
+    // argv[1] is the operation name; argv[2..6] are the five short options
+    if(argc < 7)
+    {
+        print_help_layernorm_bwd_data();
+        return 0;
+    }
+
+    layernormBwdDataArgParser arg_parser;
+
+    // short unnamed options
+    const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
+    const bool do_verification       = std::stoi(argv[3]);
+    const int init_method            = std::stoi(argv[4]);
+    const bool do_log                = std::stoi(argv[5]);
+    const bool time_kernel           = std::stoi(argv[6]);
+
+    // parse the long options
+    arg_parser(argc, argv);
+    const std::vector<index_t> length = arg_parser.long_opts["length"];
+
+    using F16 = ck::half_t;
+    using F32 = float;
+
+    if(length.size() == 2)
+    {
+        constexpr int rank = 2;
+
+        if(data_type == ck::DataTypeEnum::Half)
+        {
+            ck::profiler::profile_layernorm_bwd_data_impl<F16, F16, F16, F16, F32, F16, rank>(
+                do_verification, init_method, do_log, time_kernel, length);
+        }
+        else if(data_type == ck::DataTypeEnum::Float)
+        {
+            ck::profiler::profile_layernorm_bwd_data_impl<F32, F32, F32, F32, F32, F32, rank>(
+                do_verification, init_method, do_log, time_kernel, length);
+        }
+        else
+        {
+            throw std::runtime_error("not implemented yet");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("length should be 2");
+    }
+
+    return 0;
+}
+
+REGISTER_PROFILER_OPERATION("layernorm_bwd_data",
+                            "Layer Normalization",
+                            profile_layernorm_bwd_data);
diff --git a/profiler/src/profile_layernorm_fwd.cpp b/profiler/src/profile_layernorm_fwd.cpp
index 9bd66e0cb8..a261bd7418 100644
--- a/profiler/src/profile_layernorm_fwd.cpp
+++ b/profiler/src/profile_layernorm_fwd.cpp
@@ -104,7 +104,7 @@ int profile_layernorm(int argc, char* argv[])
 
     if(data_type == ck::DataTypeEnum::Half)
     {
-        ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, 2>(
+        ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F16, 2>(
             do_verification, init_method, do_log, time_kernel, length);
     }
     else if(data_type == ck::DataTypeEnum::Float)
@@ -125,4 +125,4 @@
     return 0;
 }
 
-REGISTER_PROFILER_OPERATION("layernorm", "Layer Normalization", profile_layernorm);
+REGISTER_PROFILER_OPERATION("layernorm_fwd", "Layer Normalization", profile_layernorm);
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index b325a3a7f8..6f7e18b0e7 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -140,6 +140,7 @@ add_subdirectory(grouped_convnd_bwd_weight)
 add_subdirectory(block_to_ctile_map)
 add_subdirectory(softmax)
 add_subdirectory(normalization_fwd)
+add_subdirectory(normalization_bwd_data)
 add_subdirectory(data_type)
 add_subdirectory(elementwise_normalization)
 add_subdirectory(batchnorm)
diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt
new file mode 100644
index 0000000000..1b6decfed7
--- /dev/null
+++ b/test/normalization_bwd_data/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_custom_target(test_normalization_bwd_data)
+add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
+    add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32)
+endif()
+
+add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance)
+    add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32)
+endif()
+
diff --git a/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp b/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp
new file mode 100644
index 0000000000..a7860955cd
--- /dev/null
+++ b/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp
@@ -0,0 +1,51 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gtest/gtest.h"
+#include "profiler/profile_groupnorm_bwd_data_impl.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+using ck::index_t;
+
+template <typename Tuple>
+class TestGroupnormBwdData : public ::testing::Test
+{
+    protected:
+    using DYDataType         = std::tuple_element_t<0, Tuple>;
+    using XDataType          = std::tuple_element_t<1, Tuple>;
+    using GammaDataType      = std::tuple_element_t<2, Tuple>;
+    using MeanInvStdDataType = std::tuple_element_t<3, Tuple>;
+    using ComputeDataType    = std::tuple_element_t<4, Tuple>;
+    using DXDataType         = std::tuple_element_t<5, Tuple>;
+
+    void Run()
+    {
+        // Bwd data: [N, H, W, G, C], reduce H, W, C
+        std::vector<std::vector<index_t>> lengths = {{1, 1, 1, 1, 1},
+                                                     {1, 2, 3, 4, 5},
+                                                     {256, 9, 9, 9, 9},
+                                                     {1, 64, 64, 32, 10},
+                                                     {1, 32, 32, 32, 20},
+                                                     {1, 16, 16, 32, 40}};
+
+        for(auto length : lengths)
+        {
+            bool success = ck::profiler::profile_groupnorm_bwd_data_impl<DYDataType,
+                                                                         XDataType,
+                                                                         GammaDataType,
+                                                                         MeanInvStdDataType,
+                                                                         ComputeDataType,
+                                                                         DXDataType>(
+                true, 2, false, false, length);
+            EXPECT_TRUE(success);
+        }
+    }
+};
+
+using KernelTypes = ::testing::Types<
+    // <DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType>
+    std::tuple<F32, F32, F32, F32, F32, F32>>;
+
+TYPED_TEST_SUITE(TestGroupnormBwdData, KernelTypes);
+TYPED_TEST(TestGroupnormBwdData, Test_FP32) { this->Run(); }
diff --git a/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp b/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp
new file mode 100644
index 0000000000..870f24d064
--- /dev/null
+++ b/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp
@@ -0,0 +1,48 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
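+
+// Typed test: the tuple fixes <DYDataType, XDataType, GammaDataType,
+// MeanInvStdDataType, ComputeDataType, DXDataType>, and Run() sweeps several
+// [N, D] extents (including odd sizes such as 511 and 1032) through every
+// registered layernorm backward-data instance.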
+ +#include "gtest/gtest.h" +#include "profiler/profile_layernorm_bwd_data_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestLayernorm2dBwdData : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using GammaDataType = std::tuple_element_t<2, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using DXDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = + ck::profiler::profile_layernorm_bwd_data_impl(true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2dBwdData, KernelTypes); +TYPED_TEST(TestLayernorm2dBwdData, Test_FP32) { this->Run(); } diff --git a/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp b/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp index 143c725257..c31161fb33 100644 --- a/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp +++ b/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp @@ -47,8 +47,8 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> - std::tuple>; + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> + std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); } diff --git a/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp b/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp index 84a833c793..08d835ed37 100644 --- a/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp +++ b/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp @@ -45,7 +45,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp b/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp index cc49ebe0ae..3234b2e159 100644 --- a/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp +++ b/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp @@ -41,8 +41,8 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> - std::tuple>; + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> + std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } diff --git a/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp b/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp index a3bd388f7f..d1a7b9e3df 100644 --- a/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp +++ b/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp @@ -41,8 +41,8 @@ class TestLayernorm4d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, 
-    std::tuple<F16, F16, F16, F32, F16>>;
+    // <XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType>
+    std::tuple<F16, F16, F16, F32, F16, F16>>;
 
 TYPED_TEST_SUITE(TestLayernorm4d, KernelTypes);
 TYPED_TEST(TestLayernorm4d, Test_FP16) { this->Run(); }
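A note on the math behind the new backward-data kernels: given dy and the x,
gamma, mean, and inv_std saved by the forward pass, they produce dx. Below is
a minimal, self-contained host sketch of the standard layernorm input
gradient; the function name, memory layout, and driver are this example's own
and are not the library's reference API.

// Standalone sketch of layernorm backward data for a row-major [N, D] tensor:
// dx = inv_std * (dy*gamma - mean_D(dy*gamma) - xhat * mean_D(dy*gamma*xhat)),
// where xhat = (x - mean) * inv_std and mean_D() averages over dimension D.
#include <cstdio>
#include <vector>

void layernorm_bwd_data_host(const std::vector<float>& dy,
                             const std::vector<float>& x,
                             const std::vector<float>& gamma,
                             const std::vector<float>& mean,    // per-row mean, size N
                             const std::vector<float>& inv_std, // per-row 1/sqrt(var), size N
                             std::vector<float>& dx,
                             int N,
                             int D)
{
    for(int n = 0; n < N; ++n)
    {
        // first pass: row sums of dy*gamma and dy*gamma*xhat
        float sum_dyg = 0.f, sum_dyg_xhat = 0.f;
        for(int d = 0; d < D; ++d)
        {
            const float xhat = (x[n * D + d] - mean[n]) * inv_std[n];
            const float dyg  = dy[n * D + d] * gamma[d];
            sum_dyg += dyg;
            sum_dyg_xhat += dyg * xhat;
        }
        // second pass: combine per-element and reduced terms
        for(int d = 0; d < D; ++d)
        {
            const float xhat = (x[n * D + d] - mean[n]) * inv_std[n];
            const float dyg  = dy[n * D + d] * gamma[d];
            dx[n * D + d]    = inv_std[n] * (dyg - (sum_dyg + xhat * sum_dyg_xhat) / D);
        }
    }
}

int main()
{
    const int N = 1, D = 4;
    const std::vector<float> dy{1.f, 0.f, 0.f, 0.f};
    const std::vector<float> x{1.f, 2.f, 3.f, 4.f};
    const std::vector<float> gamma{1.f, 1.f, 1.f, 1.f};
    const std::vector<float> mean{2.5f};         // forward-pass mean over D
    const std::vector<float> inv_std{0.894427f}; // 1/sqrt(var), var = 1.25
    std::vector<float> dx(N * D, 0.f);

    layernorm_bwd_data_host(dy, x, gamma, mean, inv_std, dx, N, D);

    for(int d = 0; d < D; ++d)
        std::printf("dx[%d] = % f\n", d, dx[d]);
    return 0;
}

For groupnorm the same per-group formula applies, with the averages taken over
the H, W, and C extents of each [N, H, W, G, C] group (the reduction the tests
above label "reduce H, W, C").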